00001 /****************************************************************************************************************/ 00002 /* */ 00003 /* OpenNN: Open Neural Networks Library */ 00004 /* www.opennn.cimne.com */ 00005 /* */ 00006 /* G R A D I E N T D E S C E N T C L A S S H E A D E R */ 00007 /* */ 00008 /* Roberto Lopez */ 00009 /* International Center for Numerical Methods in Engineering (CIMNE) */ 00010 /* Technical University of Catalonia (UPC) */ 00011 /* Barcelona, Spain */ 00012 /* E-mail: rlopez@cimne.upc.edu */ 00013 /* */ 00014 /****************************************************************************************************************/ 00015 00016 #ifndef __GRADIENTDESCENT_H__ 00017 #define __GRADIENTDESCENT_H__ 00018 00019 // OpenNN includes 00020 00021 #include "../performance_functional/performance_functional.h" 00022 00023 #include "training_algorithm.h" 00024 #include "training_rate_algorithm.h" 00025 00026 00027 namespace OpenNN 00028 { 00029 00032 00033 class GradientDescent : public TrainingAlgorithm 00034 { 00035 00036 public: 00037 00038 // DEFAULT CONSTRUCTOR 00039 00040 explicit GradientDescent(void); 00041 00042 // PERFORMANCE FUNCTIONAL CONSTRUCTOR 00043 00044 explicit GradientDescent(PerformanceFunctional*); 00045 00046 // XML CONSTRUCTOR 00047 00048 explicit GradientDescent(TiXmlElement*); 00049 00050 00051 // DESTRUCTOR 00052 00053 virtual ~GradientDescent(void); 00054 00055 // STRUCTURES 00056 00060 00061 struct GradientDescentResults : public TrainingAlgorithm::Results 00062 { 00063 // Training history 00064 00066 00067 Vector< Vector<double> > parameters_history; 00068 00070 00071 Vector<double> parameters_norm_history; 00072 00074 00075 Vector<double> evaluation_history; 00076 00078 00079 Vector<double> generalization_evaluation_history; 00080 00082 00083 Vector< Vector<double> > gradient_history; 00084 00086 00087 Vector<double> gradient_norm_history; 00088 00090 00091 Vector< Vector<double> > training_direction_history; 00092 00094 00095 Vector<double> training_rate_history; 00096 00098 00099 Vector<double> elapsed_time_history; 00100 00101 // Final values 00102 00104 00105 Vector<double> final_parameters; 00106 00108 00109 double final_parameters_norm; 00110 00112 00113 double final_evaluation; 00114 00116 00117 double final_generalization_evaluation; 00118 00120 00121 Vector<double> final_gradient; 00122 00124 00125 double final_gradient_norm; 00126 00128 00129 Vector<double> final_training_direction; 00130 00132 00133 double final_training_rate; 00134 00136 00137 double elapsed_time; 00138 00139 void resize_training_history(const unsigned int&); 00140 std::string to_string(void) const; 00141 }; 00142 00143 // METHODS 00144 00145 const TrainingRateAlgorithm& get_training_rate_algorithm(void) const; 00146 TrainingRateAlgorithm* get_training_rate_algorithm_pointer(void); 00147 00148 // Training parameters 00149 00150 const double& get_warning_parameters_norm(void) const; 00151 const double& get_warning_gradient_norm(void) const; 00152 const double& get_warning_training_rate(void) const; 00153 00154 const double& get_error_parameters_norm(void) const; 00155 const double& get_error_gradient_norm(void) const; 00156 const double& get_error_training_rate(void) const; 00157 00158 // Stopping criteria 00159 00160 const double& get_minimum_parameters_increment_norm(void) const; 00161 00162 const double& get_minimum_performance_increase(void) const; 00163 const double& get_performance_goal(void) const; 00164 const double& get_gradient_norm_goal(void) const; 00165 const unsigned int& get_maximum_generalization_evaluation_decreases(void) const; 00166 00167 const unsigned int& get_maximum_epochs_number(void) const; 00168 const double& get_maximum_time(void) const; 00169 00170 // Reserve training history 00171 00172 const bool& get_reserve_parameters_history(void) const; 00173 const bool& get_reserve_parameters_norm_history(void) const; 00174 00175 const bool& get_reserve_evaluation_history(void) const; 00176 const bool& get_reserve_gradient_history(void) const; 00177 const bool& get_reserve_gradient_norm_history(void) const; 00178 const bool& get_reserve_generalization_evaluation_history(void) const; 00179 00180 const bool& get_reserve_training_direction_history(void) const; 00181 const bool& get_reserve_training_rate_history(void) const; 00182 const bool& get_reserve_elapsed_time_history(void) const; 00183 00184 // Utilities 00185 00186 const unsigned int& get_display_period(void) const; 00187 00188 // Set methods 00189 00190 void set_training_rate_algorithm(const TrainingRateAlgorithm&); 00191 00192 00193 void set_default(void); 00194 00195 void set_reserve_all_training_history(const bool&); 00196 00197 00198 // Training parameters 00199 00200 void set_warning_parameters_norm(const double&); 00201 void set_warning_gradient_norm(const double&); 00202 void set_warning_training_rate(const double&); 00203 00204 void set_error_parameters_norm(const double&); 00205 void set_error_gradient_norm(const double&); 00206 void set_error_training_rate(const double&); 00207 00208 // Stopping criteria 00209 00210 void set_minimum_parameters_increment_norm(const double&); 00211 00212 void set_minimum_performance_increase(const double&); 00213 void set_performance_goal(const double&); 00214 void set_gradient_norm_goal(const double&); 00215 void set_maximum_generalization_evaluation_decreases(const unsigned int&); 00216 00217 void set_maximum_epochs_number(const unsigned int&); 00218 void set_maximum_time(const double&); 00219 00220 // Reserve training history 00221 00222 void set_reserve_parameters_history(const bool&); 00223 void set_reserve_parameters_norm_history(const bool&); 00224 00225 void set_reserve_evaluation_history(const bool&); 00226 void set_reserve_gradient_history(const bool&); 00227 void set_reserve_gradient_norm_history(const bool&); 00228 void set_reserve_generalization_evaluation_history(const bool&); 00229 00230 void set_reserve_training_direction_history(const bool&); 00231 void set_reserve_training_rate_history(const bool&); 00232 void set_reserve_elapsed_time_history(const bool&); 00233 00234 // Utilities 00235 00236 void set_display_period(const unsigned int&); 00237 00238 // Training methods 00239 00240 Vector<double> calculate_training_direction(const Vector<double>&) const; 00241 00242 GradientDescentResults* perform_training(void); 00243 00244 std::string write_training_algorithm_type(void) const; 00245 00246 // Serialization methods 00247 00248 TiXmlElement* to_XML(void) const; 00249 void from_XML(TiXmlElement*); 00250 00251 private: 00252 00253 // TRAINING OPERATORS 00254 00256 00257 TrainingRateAlgorithm training_rate_algorithm; 00258 00259 // TRAINING PARAMETERS 00260 00262 00263 double warning_parameters_norm; 00264 00266 00267 double warning_gradient_norm; 00268 00270 00271 double warning_training_rate; 00272 00274 00275 double error_parameters_norm; 00276 00278 00279 double error_gradient_norm; 00280 00282 00283 double error_training_rate; 00284 00285 00286 // STOPPING CRITERIA 00287 00289 00290 double minimum_parameters_increment_norm; 00291 00293 00294 double minimum_performance_increase; 00295 00297 00298 double performance_goal; 00299 00301 00302 double gradient_norm_goal; 00303 00304 unsigned int maximum_generalization_evaluation_decreases; 00305 00307 00308 unsigned int maximum_epochs_number; 00309 00311 00312 double maximum_time; 00313 00314 // TRAINING HISTORY 00315 00317 00318 bool reserve_parameters_history; 00319 00321 00322 bool reserve_parameters_norm_history; 00323 00325 00326 bool reserve_evaluation_history; 00327 00329 00330 bool reserve_gradient_history; 00331 00333 00334 bool reserve_gradient_norm_history; 00335 00337 00338 bool reserve_training_direction_history; 00339 00341 00342 bool reserve_training_rate_history; 00343 00345 00346 bool reserve_elapsed_time_history; 00347 00349 00350 bool reserve_generalization_evaluation_history; 00351 00353 00354 unsigned int display_period; 00355 00356 }; 00357 00358 } 00359 00360 #endif