00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
#include <algorithm>
#include <cmath>
#include <ctime>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>



#include "training_algorithm.h"

#include "gradient_descent.h"
#include "conjugate_gradient.h"
#include "quasi_newton_method.h"



#include "../../parsers/tinyxml/tinyxml.h"
00037
00038 namespace OpenNN
00039 {
00040
00041
00042
00043
00046
00047 TrainingAlgorithm::TrainingAlgorithm(void)
00048 : performance_functional_pointer(NULL)
00049 {
00050 set_default();
00051 }
00052
00053
00054
00055
00059
00060 TrainingAlgorithm::TrainingAlgorithm(PerformanceFunctional* new_performance_functional_pointer)
00061 : performance_functional_pointer(new_performance_functional_pointer)
00062 {
00063 set_default();
00064 }
00065
00066
00067
00068
00072
00073 TrainingAlgorithm::TrainingAlgorithm(TiXmlElement* training_algorithm_element)
00074 : performance_functional_pointer(NULL)
00075 {
00076 from_XML(training_algorithm_element);
00077 }
00078
00079
00080
00081
00083
00084 TrainingAlgorithm::~TrainingAlgorithm(void)
00085 {
00086 }
00087
00088
00089
00090
00091
00092
00095
00096 PerformanceFunctional* TrainingAlgorithm::get_performance_functional_pointer(void) const
00097 {
00098 return(performance_functional_pointer);
00099 }
00100
00101
00102
00103
00106
00107 const bool& TrainingAlgorithm::get_display(void) const
00108 {
00109 return(display);
00110 }
00111
00112
00113
00114
00117
00118 void TrainingAlgorithm::set(void)
00119 {
00120 performance_functional_pointer = NULL;
00121
00122 set_default();
00123 }
00124
00125
00126
00127
00131
00132 void TrainingAlgorithm::set(PerformanceFunctional* new_performance_functional_pointer)
00133 {
00134 performance_functional_pointer = new_performance_functional_pointer;
00135
00136 set_default();
00137 }
00138
00139
00140
00141
00144
00145 void TrainingAlgorithm::set_performance_functional_pointer(PerformanceFunctional* new_performance_functional_pointer)
00146 {
00147 performance_functional_pointer = new_performance_functional_pointer;
00148 }
00149
00150
00151
00152
00157
00158 void TrainingAlgorithm::set_display(const bool& new_display)
00159 {
00160 display = new_display;
00161 }
00162
00163
00164
00165
00167
00168 void TrainingAlgorithm::set_default(void)
00169 {
00170 display = true;
00171 }
00172
00173
00174
00175
00177
00178 std::string TrainingAlgorithm::write_training_algorithm_type(void) const
00179 {
00180 return("USER_TRAINING_ALGORITHM");
00181 }
00182
00183
00184
00185
00190
00191 void TrainingAlgorithm::check(void) const
00192 {
00193 std::ostringstream buffer;
00194
00195 if(!performance_functional_pointer)
00196 {
00197 buffer << "OpenNN Exception: TrainingAlgorithm class.\n"
00198 << "void check(void) const method.\n"
00199 << "Pointer to performance functional is NULL.\n";
00200
00201 throw std::logic_error(buffer.str().c_str());
00202 }
00203
00204 const NeuralNetwork* neural_network_pointer = performance_functional_pointer->get_neural_network_pointer();
00205
00206 if(neural_network_pointer == NULL)
00207 {
00208 buffer << "OpenNN Exception: TrainingAlgorithm class.\n"
00209 << "void check(void) const method.\n"
00210 << "Pointer to neural network is NULL.\n";
00211
00212 throw std::logic_error(buffer.str().c_str());
00213 }
00214 }
00215
00216
00217
00218
00220
00221 std::string TrainingAlgorithm::to_string(void) const
00222 {
00223 std::ostringstream buffer;
00224
00225 buffer << "Training strategy\n"
00226 << "Display: " << display << "\n";
00227
00228 return(buffer.str());
00229 }
00230
00231
00232
00233
00236
00237 TiXmlElement* TrainingAlgorithm::to_XML(void) const
00238 {
00239 std::ostringstream buffer;
00240
00241
00242
00243 TiXmlElement* training_algorithm_element = new TiXmlElement("TrainingAlgorithm");
00244 training_algorithm_element->SetAttribute("Version", 4);
00245
00246
00247 {
00248 TiXmlElement* display_element = new TiXmlElement("Display");
00249 training_algorithm_element->LinkEndChild(display_element);
00250
00251 buffer.str("");
00252 buffer << display;
00253
00254 TiXmlText* display_text = new TiXmlText(buffer.str().c_str());
00255 display_element->LinkEndChild(display_text);
00256 }
00257
00258 return(training_algorithm_element);
00259 }
00260
00261
00262
00263
00266
00267 void TrainingAlgorithm::from_XML(TiXmlElement* training_algorithm_element)
00268 {
00269
00270
00271 TiXmlElement* display_element = training_algorithm_element->FirstChildElement("Display");
00272
00273 if(display_element)
00274 {
00275 std::string new_display = display_element->GetText();
00276
00277 try
00278 {
00279 set_display(new_display != "0");
00280 }
00281 catch(std::exception& e)
00282 {
00283 std::cout << e.what() << std::endl;
00284 }
00285 }
00286 }
00287
00288
00289
00290
00292
00293 void TrainingAlgorithm::print(void) const
00294 {
00295 std::cout << to_string();
00296 }
00297
00298
00299
00300
00303
00304 void TrainingAlgorithm::save(const std::string& filename) const
00305 {
00306 std::ostringstream buffer;
00307
00308 TiXmlDocument document;
00309
00310
00311
00312 TiXmlDeclaration* declaration = new TiXmlDeclaration("1.0", "", "");
00313 document.LinkEndChild(declaration);
00314
00315
00316
00317 TiXmlElement* training_algorithm_element = to_XML();
00318 document.LinkEndChild(training_algorithm_element);
00319
00320 document.SaveFile(filename.c_str());
00321 }
00322
00323
00324
00325
00329
00330 void TrainingAlgorithm::load(const std::string& filename)
00331 {
00332 set_default();
00333
00334 std::ostringstream buffer;
00335
00336 TiXmlDocument document(filename.c_str());
00337
00338 if (!document.LoadFile())
00339 {
00340 buffer << "OpenNN Exception: TrainingAlgorithm class.\n"
00341 << "void load(const std::string&) method.\n"
00342 << "Cannot load XML file " << filename << ".\n";
00343
00344 throw std::logic_error(buffer.str());
00345 }
00346
00347
00348
00349 TiXmlElement* training_algorithm_element = document.FirstChildElement("TrainingAlgorithm");
00350
00351 if(!training_algorithm_element)
00352 {
00353 buffer << "OpenNN Exception: TrainingAlgorithm class.\n"
00354 << "void load(const std::string&) method.\n"
00355 << "File " << filename << " is not a valid training algorithm file.\n";
00356
00357 throw std::logic_error(buffer.str());
00358 }
00359 }
00360
00361 }
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377
00378
00379