00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016 #ifndef __TRAININGRATEALGORITHM_H__
00017 #define __TRAININGRATEALGORITHM_H__
00018
00019
00020
00021 #include "../neural_network/neural_network.h"
00022 #include "../performance_functional/performance_functional.h"
00023
00024
00025
00026 #include "../../parsers/tinyxml/tinyxml.h"
00027
00028 namespace OpenNN
00029 {
00030
00033
00034 class TrainingRateAlgorithm
00035 {
00036
00037 public:
00038
00039
00040
00042
00043 enum TrainingRateMethod{Fixed, GoldenSection, BrentMethod};
00044
00045
00046
00047 explicit TrainingRateAlgorithm(void);
00048
00049
00050
00051 explicit TrainingRateAlgorithm(PerformanceFunctional*);
00052
00053
00054
00055 explicit TrainingRateAlgorithm(TiXmlElement*);
00056
00057
00058
00059 virtual ~TrainingRateAlgorithm(void);
00060
00061
00062
00063
00064
00065
00066 PerformanceFunctional* get_performance_functional_pointer(void);
00067
00068
00069
00070 const TrainingRateMethod& get_training_rate_method(void) const;
00071 std::string write_training_rate_method(void) const;
00072
00073
00074
00075 const double& get_first_training_rate(void) const;
00076 const double& get_bracketing_factor(void) const;
00077 const double& get_training_rate_tolerance(void) const;
00078
00079 const double& get_warning_training_rate(void) const;
00080
00081 const double& get_error_training_rate(void) const;
00082
00083
00084
00085 const bool& get_display(void) const;
00086
00087
00088
00089 void set(void);
00090 void set(PerformanceFunctional*);
00091
00092 void set_performance_functional_pointer(PerformanceFunctional*);
00093
00094
00095
00096 void set_training_rate_method(const TrainingRateMethod&);
00097 void set_training_rate_method(const std::string&);
00098
00099
00100
00101 void set_first_training_rate(const double&);
00102 void set_bracketing_factor(const double&);
00103 void set_training_rate_tolerance(const double&);
00104
00105 void set_warning_training_rate(const double&);
00106
00107 void set_error_training_rate(const double&);
00108
00109
00110
00111 void set_display(const bool&);
00112
00113 virtual void set_default(void);
00114
00115
00116
00117 double calculate_golden_section_training_rate(const Vector<double>&, const Vector<double>&, const Vector<double>&) const;
00118 double calculate_Brent_method_training_rate(const Vector<double>&, const Vector<double>&, const Vector<double>&) const;
00119
00120 Vector< Vector<double> > calculate_bracketing_training_rate(const double&, const Vector<double>&, const double&) const;
00121
00122 Vector<double> calculate_fixed_directional_point(const double&, const Vector<double>&, const double&) const;
00123 Vector<double> calculate_golden_section_directional_point(const double&, const Vector<double>&, const double&) const;
00124 Vector<double> calculate_Brent_method_directional_point(const double&, const Vector<double>&, const double&) const;
00125
00126 Vector<double> calculate_directional_point(const double&, const Vector<double>&, const double&) const;
00127
00128
00129
00130 virtual TiXmlElement* to_XML(void) const;
00131 virtual void from_XML(TiXmlElement*);
00132
00133
00134 protected:
00135
00136
00137
00139
00140 PerformanceFunctional* performance_functional_pointer;
00141
00142
00143
00144
00146
00147 TrainingRateMethod training_rate_method;
00148
00150
00151 double bracketing_factor;
00152
00154
00155 double first_training_rate;
00156
00158
00159 double training_rate_tolerance;
00160
00162
00163 double warning_training_rate;
00164
00166
00167 double error_training_rate;
00168
00169
00170
00172
00173 bool display;
00174 };
00175
00176 }
00177
00178 #endif
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197