/****************************************************************************************************************/
/*                                                                                                              */
/*   OpenNN: Open Neural Networks Library                                                                       */
/*   www.opennn.cimne.com                                                                                       */
/*                                                                                                              */
/*   C O N J U G A T E   G R A D I E N T   C L A S S   H E A D E R                                              */
/*                                                                                                              */
/*   Roberto Lopez                                                                                              */
/*   International Center for Numerical Methods in Engineering (CIMNE)                                          */
/*   Technical University of Catalonia (UPC)                                                                    */
/*   Barcelona, Spain                                                                                           */
/*   E-mail: rlopez@cimne.upc.edu                                                                               */
/*                                                                                                              */
/****************************************************************************************************************/


#ifndef __CONJUGATEGRADIENT_H__
#define __CONJUGATEGRADIENT_H__

// OpenNN includes

#include "../performance_functional/performance_functional.h"

#include "training_algorithm.h"
#include "training_rate_algorithm.h"


namespace OpenNN
{

/// This concrete class represents a conjugate gradient training algorithm
/// for a performance functional of a neural network.

class ConjugateGradient : public TrainingAlgorithm
{

public:

   // ENUMERATIONS

   /// Enumeration of the available training operators for obtaining the training direction:
   /// Polak-Ribiere (PR) and Fletcher-Reeves (FR).

   enum TrainingDirectionMethod{PR, FR};

   // DEFAULT CONSTRUCTOR

   explicit ConjugateGradient(void);


   // GENERAL CONSTRUCTOR

   explicit ConjugateGradient(PerformanceFunctional*);


   // XML CONSTRUCTOR

   explicit ConjugateGradient(TiXmlElement*);


   // DESTRUCTOR

   virtual ~ConjugateGradient(void);


   // STRUCTURES

   /// This structure contains the training results of the conjugate gradient method:
   /// the training history and the final values.

   struct ConjugateGradientResults : public TrainingAlgorithm::Results
   {
      // TRAINING HISTORY

      /// History of the neural network parameters over the training epochs.

      Vector< Vector<double> > parameters_history;

      /// History of the parameters norm over the training epochs.

      Vector<double> parameters_norm_history;

      /// History of the performance functional evaluation over the training epochs.

      Vector<double> evaluation_history;

      /// History of the generalization evaluation over the training epochs.

      Vector<double> generalization_evaluation_history;

      /// History of the performance functional gradient over the training epochs.

      Vector< Vector<double> > gradient_history;

      /// History of the gradient norm over the training epochs.

      Vector<double> gradient_norm_history;

      /// History of the conjugate gradient training directions over the training epochs.

      Vector< Vector<double> > training_direction_history;

      /// History of the training rate over the training epochs.

      Vector<double> training_rate_history;

      /// History of the elapsed time over the training epochs.

      Vector<double> elapsed_time_history;

      // FINAL VALUES

      /// Final neural network parameters.

      Vector<double> final_parameters;

      /// Final neural network parameters norm.

      double final_parameters_norm;

      /// Final performance functional evaluation.

      double final_evaluation;

      /// Final generalization evaluation.

      double final_generalization_evaluation;

      /// Final performance functional gradient.

      Vector<double> final_gradient;

      /// Final gradient norm.

      double final_gradient_norm;

      /// Final conjugate gradient training direction.

      Vector<double> final_training_direction;

      /// Final training rate.

      double final_training_rate;

      /// Elapsed time of the training process.

      double elapsed_time;

      void resize_training_history(const unsigned int&);
      std::string to_string(void) const;
   };


   // METHODS

   // Get methods

   const TrainingRateAlgorithm& get_training_rate_algorithm(void) const;
   TrainingRateAlgorithm* get_training_rate_algorithm_pointer(void);

   // Training operators

   const TrainingDirectionMethod& get_training_direction_method(void) const;
   std::string write_training_direction_method(void) const;

   // Training parameters

   const double& get_warning_parameters_norm(void) const;
   const double& get_warning_gradient_norm(void) const;
   const double& get_warning_training_rate(void) const;

   const double& get_error_parameters_norm(void) const;
   const double& get_error_gradient_norm(void) const;
   const double& get_error_training_rate(void) const;

   // Stopping criteria

   const double& get_minimum_parameters_increment_norm(void) const;

   const double& get_minimum_performance_increase(void) const;
   const double& get_performance_goal(void) const;
   const unsigned int& get_maximum_generalization_evaluation_decreases(void) const;
   const double& get_gradient_norm_goal(void) const;

   const unsigned int& get_maximum_epochs_number(void) const;
   const double& get_maximum_time(void) const;

   // Reserve training history

   const bool& get_reserve_parameters_history(void) const;
   const bool& get_reserve_parameters_norm_history(void) const;

   const bool& get_reserve_evaluation_history(void) const;
   const bool& get_reserve_generalization_evaluation_history(void) const;
   const bool& get_reserve_gradient_history(void) const;
   const bool& get_reserve_gradient_norm_history(void) const;

   const bool& get_reserve_training_direction_history(void) const;
   const bool& get_reserve_training_rate_history(void) const;
   const bool& get_reserve_elapsed_time_history(void) const;

   // Utilities

   const unsigned int& get_display_period(void) const;

   // Set methods

   void set_default(void);

   // Training operators

   void set_training_direction_method(const TrainingDirectionMethod&);
   void set_training_direction_method(const std::string&);

   // Training parameters

   void set_warning_parameters_norm(const double&);
   void set_warning_gradient_norm(const double&);
   void set_warning_training_rate(const double&);

   void set_error_parameters_norm(const double&);
   void set_error_gradient_norm(const double&);
   void set_error_training_rate(const double&);

   // Stopping criteria

   void set_minimum_parameters_increment_norm(const double&);

   void set_performance_goal(const double&);
   void set_minimum_performance_increase(const double&);
   void set_maximum_generalization_evaluation_decreases(const unsigned int&);
   void set_gradient_norm_goal(const double&);

   void set_maximum_epochs_number(const unsigned int&);
   void set_maximum_time(const double&);

   // Reserve training history

   void set_reserve_parameters_history(const bool&);
   void set_reserve_parameters_norm_history(const bool&);

   void set_reserve_evaluation_history(const bool&);
   void set_reserve_generalization_evaluation_history(const bool&);
   void set_reserve_gradient_history(const bool&);
   void set_reserve_gradient_norm_history(const bool&);

   void set_reserve_training_direction_history(const bool&);
   void set_reserve_training_rate_history(const bool&);
   void set_reserve_elapsed_time_history(const bool&);

   void set_reserve_all_training_history(const bool&);

   // Utilities

   void set_display_period(const unsigned int&);

   // Training direction methods

   double calculate_PR_parameter(const Vector<double>&, const Vector<double>&) const;
   double calculate_FR_parameter(const Vector<double>&, const Vector<double>&) const;

   Vector<double> calculate_PR_training_direction(const Vector<double>&, const Vector<double>&, const Vector<double>&) const;
   Vector<double> calculate_FR_training_direction(const Vector<double>&, const Vector<double>&, const Vector<double>&) const;

   Vector<double> calculate_training_direction(const Vector<double>&, const Vector<double>&, const Vector<double>&) const;

   Vector<double> calculate_gradient_descent_training_direction(const Vector<double>&) const;

   // Training methods

   ConjugateGradientResults* perform_training(void);

   std::string write_training_algorithm_type(void) const;

   // Serialization methods

   TiXmlElement* to_XML(void) const;
   void from_XML(TiXmlElement*);

private:

   /// Training direction method (Polak-Ribiere or Fletcher-Reeves).

   TrainingDirectionMethod training_direction_method;

   /// Training rate algorithm object for the one-dimensional minimization.

   TrainingRateAlgorithm training_rate_algorithm;

   /// Value for the parameters norm at which a warning message is written to the screen.

   double warning_parameters_norm;

   /// Value for the gradient norm at which a warning message is written to the screen.

   double warning_gradient_norm;

   /// Training rate value at which a warning message is written to the screen.

   double warning_training_rate;

   /// Value for the parameters norm at which the training process is assumed to fail.

   double error_parameters_norm;

   /// Value for the gradient norm at which the training process is assumed to fail.

   double error_gradient_norm;

   /// Training rate value at which the training process is assumed to fail.

   double error_training_rate;


   // STOPPING CRITERIA

   /// Norm of the parameters increment vector at which training stops.

   double minimum_parameters_increment_norm;

   /// Minimum performance improvement between two successive epochs.

   double minimum_performance_increase;

   /// Goal value for the performance. It is used as a stopping criterion.

   double performance_goal;

   /// Goal value for the norm of the performance functional gradient. It is used as a stopping criterion.

   double gradient_norm_goal;

   /// Maximum number of epochs at which the generalization evaluation decreases.

   unsigned int maximum_generalization_evaluation_decreases;

   /// Maximum number of epochs to perform training. It is used as a stopping criterion.

   unsigned int maximum_epochs_number;

   /// Maximum training time. It is used as a stopping criterion.

   double maximum_time;

   // TRAINING HISTORY

   /// True if the parameters history vector of vectors is to be reserved, false otherwise.

   bool reserve_parameters_history;

   /// True if the parameters norm history vector is to be reserved, false otherwise.

   bool reserve_parameters_norm_history;

   /// True if the evaluation history vector is to be reserved, false otherwise.

   bool reserve_evaluation_history;

   /// True if the gradient history vector of vectors is to be reserved, false otherwise.

   bool reserve_gradient_history;

   /// True if the gradient norm history vector is to be reserved, false otherwise.

   bool reserve_gradient_norm_history;

   /// True if the training direction history is to be reserved, false otherwise.

   bool reserve_training_direction_history;

   /// True if the training rate history vector is to be reserved, false otherwise.

   bool reserve_training_rate_history;

   /// True if the elapsed time history vector is to be reserved, false otherwise.

   bool reserve_elapsed_time_history;

   /// True if the generalization evaluation history vector is to be reserved, false otherwise.

   bool reserve_generalization_evaluation_history;

   /// Number of epochs between the training showing progress.

   unsigned int display_period;


};

}

#endif


// OpenNN: Open Neural Networks Library.
// Copyright (C) 2005-2012 Roberto Lopez
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.

// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
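

// A minimal, self-contained sketch of the Polak-Ribiere (PR) and Fletcher-Reeves (FR)
// updates that the calculate_PR_parameter(), calculate_FR_parameter() and
// calculate_training_direction() methods declared above are named after. It uses
// std::vector<double> in place of OpenNN's Vector<double>, and the helper names below
// (dot, fr_parameter, pr_parameter, conjugate_direction) are illustrative only; they do
// not mirror the library implementation.

#include <cstddef>
#include <numeric>
#include <vector>

namespace
{

// Inner product of two equally sized vectors.
double dot(const std::vector<double>& x, const std::vector<double>& y)
{
   return(std::inner_product(x.begin(), x.end(), y.begin(), 0.0));
}

// Fletcher-Reeves parameter: ||g||^2 / ||g_old||^2.
double fr_parameter(const std::vector<double>& old_gradient, const std::vector<double>& gradient)
{
   const double denominator = dot(old_gradient, old_gradient);

   return(denominator > 0.0 ? dot(gradient, gradient)/denominator : 0.0);
}

// Polak-Ribiere parameter: g.(g - g_old) / ||g_old||^2.
// Implementations often clamp this at zero (the "PR+" rule) so that a poor step
// restarts the search along the plain gradient descent direction.
double pr_parameter(const std::vector<double>& old_gradient, const std::vector<double>& gradient)
{
   const double denominator = dot(old_gradient, old_gradient);

   if(denominator <= 0.0)
   {
      return(0.0);
   }

   double numerator = 0.0;

   for(std::size_t i = 0; i < gradient.size(); i++)
   {
      numerator += gradient[i]*(gradient[i] - old_gradient[i]);
   }

   return(numerator/denominator);
}

// Conjugate training direction: d = -g + beta*d_old, where beta is the PR or FR parameter.
std::vector<double> conjugate_direction(const std::vector<double>& gradient,
                                        const std::vector<double>& old_direction,
                                        const double& beta)
{
   std::vector<double> direction(gradient.size());

   for(std::size_t i = 0; i < gradient.size(); i++)
   {
      direction[i] = -gradient[i] + beta*old_direction[i];
   }

   return(direction);
}

}

// Typical driving code for the class declared in this header, assuming a
// PerformanceFunctional object has been assembled elsewhere (its construction is
// outside the scope of this file):
//
//    OpenNN::ConjugateGradient conjugate_gradient(&performance_functional);
//    conjugate_gradient.set_training_direction_method(OpenNN::ConjugateGradient::PR);
//    OpenNN::ConjugateGradient::ConjugateGradientResults* results = conjugate_gradient.perform_training();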