00001 /****************************************************************************************************************/ 00002 /* */ 00003 /* OpenNN: Open Neural Networks Library */ 00004 /* www.opennn.cimne.com */ 00005 /* */ 00006 /* T R A I N I N G S T R A T E G Y C L A S S H E A D E R */ 00007 /* */ 00008 /* Roberto Lopez */ 00009 /* International Center for Numerical Methods in Engineering (CIMNE) */ 00010 /* Technical University of Catalonia (UPC) */ 00011 /* Barcelona, Spain */ 00012 /* E-mail: rlopez@cimne.upc.edu */ 00013 /* */ 00014 /****************************************************************************************************************/ 00015 00016 #ifndef __TRAININGSTRATEGY_H__ 00017 #define __TRAININGSTRATEGY_H__ 00018 00019 // OpenNN includes 00020 00021 #include "../performance_functional/performance_functional.h" 00022 00023 #include "training_algorithm.h" 00024 00025 // TinyXml includes 00026 00027 #include "../../parsers/tinyxml/tinyxml.h" 00028 00029 namespace OpenNN 00030 { 00031 00039 00040 class TrainingStrategy 00041 { 00042 00043 public: 00044 00045 // DEFAULT CONSTRUCTOR 00046 00047 explicit TrainingStrategy(void); 00048 00049 // GENERAL CONSTRUCTOR 00050 00051 explicit TrainingStrategy(PerformanceFunctional*); 00052 00053 // XML CONSTRUCTOR 00054 00055 explicit TrainingStrategy(TiXmlElement*); 00056 00057 // FILE CONSTRUCTOR 00058 00059 explicit TrainingStrategy(const std::string&); 00060 00061 // DESTRUCTOR 00062 00063 virtual ~TrainingStrategy(void); 00064 00065 // ENUMERATIONS 00066 00068 00069 enum TrainingAlgorithmType 00070 { 00071 NONE, 00072 RANDOM_SEARCH, 00073 EVOLUTIONARY_ALGORITHM, 00074 GRADIENT_DESCENT, 00075 CONJUGATE_GRADIENT, 00076 QUASI_NEWTON_METHOD, 00077 LEVENBERG_MARQUARDT_ALGORITHM, 00078 NEWTON_METHOD, 00079 USER_TRAINING_ALGORITHM 00080 }; 00081 00082 // STRUCTURES 00083 00086 00087 struct Results 00088 { 00090 00091 TrainingAlgorithm::Results* initialization_training_algorithm_results_pointer; 00092 00094 00095 TrainingAlgorithm::Results* main_training_algorithm_results_pointer; 00096 00098 00099 TrainingAlgorithm::Results* refinement_training_algorithm_results_pointer; 00100 00101 00102 void save(const std::string&) const; 00103 }; 00104 00105 // METHODS 00106 00107 // Get methods 00108 00109 PerformanceFunctional* get_performance_functional_pointer(void) const; 00110 00111 TrainingAlgorithm* get_initialization_training_algorithm_pointer(void) const; 00112 TrainingAlgorithm* get_main_training_algorithm_pointer(void) const; 00113 TrainingAlgorithm* get_refinement_training_algorithm_pointer(void) const; 00114 00115 const TrainingAlgorithmType& get_initialization_training_algorithm_type(void) const; 00116 const TrainingAlgorithmType& get_main_training_algorithm_type(void) const; 00117 const TrainingAlgorithmType& get_refinement_training_algorithm_type(void) const; 00118 00119 std::string write_initialization_training_algorithm_type(void) const; 00120 std::string write_main_training_algorithm_type(void) const; 00121 std::string write_refinement_training_algorithm_type(void) const; 00122 00123 const bool& get_initialization_training_algorithm_flag(void); 00124 const bool& get_main_training_algorithm_flag(void); 00125 const bool& get_refinement_training_algorithm_flag(void); 00126 00127 const bool& get_display(void) const; 00128 00129 // Set methods 00130 00131 void set(void); 00132 void set(PerformanceFunctional*); 00133 virtual void set_default(void); 00134 00135 void set_performance_functional_pointer(PerformanceFunctional*); 00136 00137 void set_initialization_training_algorithm_pointer(TrainingAlgorithm*); 00138 void set_main_training_algorithm_pointer(TrainingAlgorithm*); 00139 void set_refinement_training_algorithm_pointer(TrainingAlgorithm*); 00140 00141 void set_initialization_training_algorithm_type(const TrainingAlgorithmType&); 00142 void set_main_training_algorithm_type(const TrainingAlgorithmType&); 00143 void set_refinement_training_algorithm_type(const TrainingAlgorithmType&); 00144 00145 void set_initialization_training_algorithm_type(const std::string&); 00146 void set_main_training_algorithm_type(const std::string&); 00147 void set_refinement_training_algorithm_type(const std::string&); 00148 00149 void set_initialization_training_algorithm_flag(const bool&); 00150 void set_main_training_algorithm_flag(const bool&); 00151 void set_refinement_training_algorithm_flag(const bool&); 00152 00153 void set_display(const bool&); 00154 00155 // Pointer methods 00156 00157 void construct_initialization_training_algorithm(const TrainingAlgorithmType&); 00158 void construct_main_training_algorithm(const TrainingAlgorithmType&); 00159 void construct_refinement_training_algorithm(const TrainingAlgorithmType&); 00160 00161 void destruct_initialization_training_algorithm(void); 00162 void destruct_main_training_algorithm(void); 00163 void destruct_refinement_training_algorithm(void); 00164 00165 // Training methods 00166 00167 // This method trains a neural network which has a performance functional associated. 00168 00169 Results perform_training(void); 00170 00171 // Serialization methods 00172 00173 std::string to_string(void) const; 00174 00175 void print(void) const; 00176 00177 virtual TiXmlElement* to_XML(void) const; 00178 virtual void from_XML(TiXmlElement*); 00179 00180 void save(const std::string&) const; 00181 void load(const std::string&); 00182 00183 protected: 00184 00186 00187 PerformanceFunctional* performance_functional_pointer; 00188 00190 00191 TrainingAlgorithm* initialization_training_algorithm_pointer; 00192 00194 00195 TrainingAlgorithm* main_training_algorithm_pointer; 00196 00198 00199 TrainingAlgorithm* refinement_training_algorithm_pointer; 00200 00201 00203 00204 TrainingAlgorithmType initialization_training_algorithm_type; 00205 00207 00208 TrainingAlgorithmType main_training_algorithm_type; 00209 00211 00212 TrainingAlgorithmType refinement_training_algorithm_type; 00213 00214 00216 00217 bool initialization_training_algorithm_flag; 00218 00220 00221 bool main_training_algorithm_flag; 00222 00224 00225 bool refinement_training_algorithm_flag; 00226 00227 00229 00230 bool display; 00231 00232 }; 00233 00234 } 00235 00236 #endif 00237 00238 00239 // OpenNN: Open Neural Networks Library. 00240 // Copyright (C) 2005-2012 Roberto Lopez 00241 // 00242 // This library is free software; you can redistribute it and/or 00243 // modify it under the terms of the GNU Lesser General Public 00244 // License as published by the Free Software Foundation; either 00245 // version 2.1 of the License, or any later version. 00246 // 00247 // This library is distributed in the hope that it will be useful, 00248 // but WITHOUT ANY WARRANTY; without even the implied warranty of 00249 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00250 // Lesser General Public License for more details. 00251 00252 // You should have received a copy of the GNU Lesser General Public 00253 // License along with this library; if not, write to the Free Software 00254 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 00255