00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #include <iostream>
00019 #include <fstream>
00020 #include <algorithm>
00021 #include <functional>
00022 #include <limits>
00023 #include <cmath>
00024 #include <ctime>
00025
00026
00027
00028 #include "training_strategy.h"
00029
00030 #include "random_search.h"
00031 #include "evolutionary_algorithm.h"
00032
00033 #include "gradient_descent.h"
00034 #include "conjugate_gradient.h"
00035 #include "quasi_newton_method.h"
00036 #include "levenberg_marquardt_algorithm.h"
00037
00038 #include "newton_method.h"
00039
00040
00041
00042 #include "../../parsers/tinyxml/tinyxml.h"
00043
00044 namespace OpenNN
00045 {
00046
00047
00048
00052
00053 TrainingStrategy::TrainingStrategy(void)
00054 : performance_functional_pointer(NULL),
00055 initialization_training_algorithm_pointer(NULL),
00056 main_training_algorithm_pointer(NULL),
00057 refinement_training_algorithm_pointer(NULL)
00058 {
00059 set_default();
00060
00061 construct_main_training_algorithm(main_training_algorithm_type);
00062 }
00063
00064
00065
00066
00071
00072 TrainingStrategy::TrainingStrategy(PerformanceFunctional* new_performance_functional_pointer)
00073 : performance_functional_pointer(new_performance_functional_pointer),
00074 initialization_training_algorithm_pointer(NULL),
00075 main_training_algorithm_pointer(NULL),
00076 refinement_training_algorithm_pointer(NULL)
00077 {
00078 set_default();
00079
00080 construct_main_training_algorithm(main_training_algorithm_type);
00081 }
00082
00083
00084
00085
00090
00091 TrainingStrategy::TrainingStrategy(TiXmlElement* training_strategy_element)
00092 : performance_functional_pointer(NULL),
00093 initialization_training_algorithm_pointer(NULL),
00094 main_training_algorithm_pointer(NULL),
00095 refinement_training_algorithm_pointer(NULL)
00096 {
00097 set_default();
00098
00099 from_XML(training_strategy_element);
00100 }
00101
00102
00103
00104
00105
00110
00111 TrainingStrategy::TrainingStrategy(const std::string& filename)
00112 : performance_functional_pointer(NULL),
00113 initialization_training_algorithm_pointer(NULL),
00114 main_training_algorithm_pointer(NULL),
00115 refinement_training_algorithm_pointer(NULL)
00116 {
00117 set_default();
00118
00119 load(filename);
00120 }
00121
00122
00123
00124
00127
00128 TrainingStrategy::~TrainingStrategy(void)
00129 {
00130 delete initialization_training_algorithm_pointer;
00131 delete main_training_algorithm_pointer;
00132 delete refinement_training_algorithm_pointer;
00133 }
00134
00135
00136
00137
00138
00139
00141
00142 PerformanceFunctional* TrainingStrategy::get_performance_functional_pointer(void) const
00143 {
00144 return(performance_functional_pointer);
00145 }
00146
00147
00148
00149
00151
00152 TrainingAlgorithm* TrainingStrategy::get_initialization_training_algorithm_pointer(void) const
00153 {
00154 return(initialization_training_algorithm_pointer);
00155 }
00156
00157
00158
00159
00161
00162 TrainingAlgorithm* TrainingStrategy::get_main_training_algorithm_pointer(void) const
00163 {
00164 return(main_training_algorithm_pointer);
00165 }
00166
00167
00168
00169
00171
00172 TrainingAlgorithm* TrainingStrategy::get_refinement_training_algorithm_pointer(void) const
00173 {
00174 return(refinement_training_algorithm_pointer);
00175 }
00176
00177
00178
00179
00181
00182 const TrainingStrategy::TrainingAlgorithmType& TrainingStrategy::get_initialization_training_algorithm_type(void) const
00183 {
00184 return(initialization_training_algorithm_type);
00185 }
00186
00187
00188
00189
00191
00192 const TrainingStrategy::TrainingAlgorithmType& TrainingStrategy::get_main_training_algorithm_type(void) const
00193 {
00194 return(main_training_algorithm_type);
00195 }
00196
00197
00198
00199
00201
00202 const TrainingStrategy::TrainingAlgorithmType& TrainingStrategy::get_refinement_training_algorithm_type(void) const
00203 {
00204 return(refinement_training_algorithm_type);
00205 }
00206
00207
00208
00209
00211
00212 std::string TrainingStrategy::write_initialization_training_algorithm_type(void) const
00213 {
00214 if(initialization_training_algorithm_type == NONE)
00215 {
00216 return("NONE");
00217 }
00218 else if(initialization_training_algorithm_type == RANDOM_SEARCH)
00219 {
00220 return("RANDOM_SEARCH");
00221 }
00222 else if(initialization_training_algorithm_type == EVOLUTIONARY_ALGORITHM)
00223 {
00224 return("EVOLUTIONARY_ALGORITHM");
00225 }
00226 else if(initialization_training_algorithm_type == GRADIENT_DESCENT)
00227 {
00228 return("GRADIENT_DESCENT");
00229 }
00230 else if(initialization_training_algorithm_type == CONJUGATE_GRADIENT)
00231 {
00232 return("CONJUGATE_GRADIENT");
00233 }
00234 else if(initialization_training_algorithm_type == QUASI_NEWTON_METHOD)
00235 {
00236 return("QUASI_NEWTON_METHOD");
00237 }
00238 else if(initialization_training_algorithm_type == LEVENBERG_MARQUARDT_ALGORITHM)
00239 {
00240 return("LEVENBERG_MARQUARDT_ALGORITHM");
00241 }
00242 else if(initialization_training_algorithm_type == NEWTON_METHOD)
00243 {
00244 return("NEWTON_METHOD");
00245 }
00246 else if(initialization_training_algorithm_type == USER_TRAINING_ALGORITHM)
00247 {
00248 return("USER_TRAINING_ALGORITHM");
00249 }
00250 else
00251 {
00252 std::ostringstream buffer;
00253
00254 buffer << "OpenNN Exception: TrainingStrategy class.\n"
00255 << "std::string write_initialization_training_algorithm_type(void) const method.\n"
00256 << "Unknown training algorithm type.\n";
00257
00258 throw std::logic_error(buffer.str());
00259 }
00260 }
00261
00262
00263
00264
00266
00267 std::string TrainingStrategy::write_main_training_algorithm_type(void) const
00268 {
00269 if(main_training_algorithm_type == NONE)
00270 {
00271 return("NONE");
00272 }
00273 else if(main_training_algorithm_type == RANDOM_SEARCH)
00274 {
00275 return("RANDOM_SEARCH");
00276 }
00277 else if(main_training_algorithm_type == EVOLUTIONARY_ALGORITHM)
00278 {
00279 return("EVOLUTIONARY_ALGORITHM");
00280 }
00281 else if(main_training_algorithm_type == GRADIENT_DESCENT)
00282 {
00283 return("GRADIENT_DESCENT");
00284 }
00285 else if(main_training_algorithm_type == CONJUGATE_GRADIENT)
00286 {
00287 return("CONJUGATE_GRADIENT");
00288 }
00289 else if(main_training_algorithm_type == QUASI_NEWTON_METHOD)
00290 {
00291 return("QUASI_NEWTON_METHOD");
00292 }
00293 else if(main_training_algorithm_type == LEVENBERG_MARQUARDT_ALGORITHM)
00294 {
00295 return("LEVENBERG_MARQUARDT_ALGORITHM");
00296 }
00297 else if(main_training_algorithm_type == NEWTON_METHOD)
00298 {
00299 return("NEWTON_METHOD");
00300 }
00301 else if(main_training_algorithm_type == USER_TRAINING_ALGORITHM)
00302 {
00303 return("USER_TRAINING_ALGORITHM");
00304 }
00305 else
00306 {
00307 std::ostringstream buffer;
00308
00309 buffer << "OpenNN Exception: TrainingStrategy class.\n"
00310 << "std::string write_main_training_algorithm_type(void) const method.\n"
00311 << "Unknown training algorithm type.\n";
00312
00313 throw std::logic_error(buffer.str());
00314 }
00315 }
00316
00317
00318
00319
00321
00322 std::string TrainingStrategy::write_refinement_training_algorithm_type(void) const
00323 {
00324 if(refinement_training_algorithm_type == NONE)
00325 {
00326 return("NONE");
00327 }
00328 else if(refinement_training_algorithm_type == RANDOM_SEARCH)
00329 {
00330 return("RANDOM_SEARCH");
00331 }
00332 else if(refinement_training_algorithm_type == EVOLUTIONARY_ALGORITHM)
00333 {
00334 return("EVOLUTIONARY_ALGORITHM");
00335 }
00336 else if(refinement_training_algorithm_type == GRADIENT_DESCENT)
00337 {
00338 return("GRADIENT_DESCENT");
00339 }
00340 else if(refinement_training_algorithm_type == CONJUGATE_GRADIENT)
00341 {
00342 return("CONJUGATE_GRADIENT");
00343 }
00344 else if(refinement_training_algorithm_type == QUASI_NEWTON_METHOD)
00345 {
00346 return("QUASI_NEWTON_METHOD");
00347 }
00348 else if(refinement_training_algorithm_type == LEVENBERG_MARQUARDT_ALGORITHM)
00349 {
00350 return("LEVENBERG_MARQUARDT_ALGORITHM");
00351 }
00352 else if(refinement_training_algorithm_type == NEWTON_METHOD)
00353 {
00354 return("NEWTON_METHOD");
00355 }
00356 else if(refinement_training_algorithm_type == USER_TRAINING_ALGORITHM)
00357 {
00358 return("USER_TRAINING_ALGORITHM");
00359 }
00360 else
00361 {
00362 std::ostringstream buffer;
00363
00364 buffer << "OpenNN Exception: TrainingStrategy class.\n"
00365 << "std::string write_refinement_training_algorithm_type(void) const method.\n"
00366 << "Unknown training algorithm type.\n";
00367
00368 throw std::logic_error(buffer.str());
00369 }
00370 }
00371
00372
00373
00374
00376
00377 const bool& TrainingStrategy::get_initialization_training_algorithm_flag(void)
00378 {
00379 return(initialization_training_algorithm_flag);
00380 }
00381
00382
00383
00384
00386
00387 const bool& TrainingStrategy::get_main_training_algorithm_flag(void)
00388 {
00389 return(main_training_algorithm_flag);
00390 }
00391
00392
00393
00394
00396
00397 const bool& TrainingStrategy::get_refinement_training_algorithm_flag(void)
00398 {
00399 return(refinement_training_algorithm_flag);
00400 }
00401
00402
00403
00404
00407
00408 const bool& TrainingStrategy::get_display(void) const
00409 {
00410 return(display);
00411 }
00412
00413
00414
00415
00419
00420 void TrainingStrategy::set(void)
00421 {
00422 performance_functional_pointer = NULL;
00423
00424 destruct_initialization_training_algorithm();
00425 destruct_main_training_algorithm();
00426 destruct_refinement_training_algorithm();
00427
00428 set_default();
00429 }
00430
00431
00432
00433
00438
00439 void TrainingStrategy::set(PerformanceFunctional* new_performance_functional_pointer)
00440 {
00441 performance_functional_pointer = new_performance_functional_pointer;
00442
00443 set_default();
00444
00445 destruct_initialization_training_algorithm();
00446 construct_main_training_algorithm(main_training_algorithm_type);
00447 destruct_refinement_training_algorithm();
00448 }
00449
00450
00451
00452
00455
00456 void TrainingStrategy::set_initialization_training_algorithm_flag(const bool& new_initialization_training_algorithm_flag)
00457 {
00458 initialization_training_algorithm_flag = new_initialization_training_algorithm_flag;
00459 }
00460
00461
00462
00463
00466
00467 void TrainingStrategy::set_main_training_algorithm_flag(const bool& new_main_training_algorithm_flag)
00468 {
00469 main_training_algorithm_flag = new_main_training_algorithm_flag;
00470 }
00471
00472
00473
00474
00477
00478 void TrainingStrategy::set_refinement_training_algorithm_flag(const bool& new_refinement_training_algorithm_flag)
00479 {
00480 refinement_training_algorithm_flag = new_refinement_training_algorithm_flag;
00481 }
00482
00483
00484
00485
00489
00490 void TrainingStrategy::set_initialization_training_algorithm_type(const TrainingAlgorithmType& new_initialization_training_algorithm_type)
00491 {
00492 initialization_training_algorithm_type = new_initialization_training_algorithm_type;
00493 }
00494
00495
00496
00497
00501
00502 void TrainingStrategy::set_main_training_algorithm_type(const TrainingAlgorithmType& new_main_training_algorithm_type)
00503 {
00504 main_training_algorithm_type = new_main_training_algorithm_type;
00505 }
00506
00507
00508
00509
00513
00514 void TrainingStrategy::set_refinement_training_algorithm_type(const TrainingAlgorithmType& new_refinement_training_algorithm_type)
00515 {
00516 refinement_training_algorithm_type = new_refinement_training_algorithm_type;
00517 }
00518
00519
00520
00521
00525
00526 void TrainingStrategy::set_initialization_training_algorithm_type(const std::string& new_training_algorithm_type)
00527 {
00528 if(new_training_algorithm_type == "NONE")
00529 {
00530 set_initialization_training_algorithm_type(NONE);
00531 }
00532 else if(new_training_algorithm_type == "RANDOM_SEARCH")
00533 {
00534 set_initialization_training_algorithm_type(RANDOM_SEARCH);
00535 }
00536 else if(new_training_algorithm_type == "EVOLUTIONARY_ALGORITHM")
00537 {
00538 set_initialization_training_algorithm_type(EVOLUTIONARY_ALGORITHM);
00539 }
00540 else if(new_training_algorithm_type == "GRADIENT_DESCENT")
00541 {
00542 set_initialization_training_algorithm_type(GRADIENT_DESCENT);
00543 }
00544 else if(new_training_algorithm_type == "CONJUGATE_GRADIENT")
00545 {
00546 set_initialization_training_algorithm_type(CONJUGATE_GRADIENT);
00547 }
00548 else if(new_training_algorithm_type == "QUASI_NEWTON_METHOD")
00549 {
00550 set_initialization_training_algorithm_type(QUASI_NEWTON_METHOD);
00551 }
00552 else if(new_training_algorithm_type == "LEVENBERG_MARQUARDT_ALGORITHM")
00553 {
00554 set_initialization_training_algorithm_type(LEVENBERG_MARQUARDT_ALGORITHM);
00555 }
00556 else if(new_training_algorithm_type == "NEWTON_METHOD")
00557 {
00558 set_initialization_training_algorithm_type(NEWTON_METHOD);
00559 }
00560 else if(new_training_algorithm_type == "USER_TRAINING_ALGORITHM")
00561 {
00562 set_initialization_training_algorithm_type(USER_TRAINING_ALGORITHM);
00563 }
00564 else
00565 {
00566 std::ostringstream buffer;
00567
00568 buffer << "OpenNN Exception: TrainingStrategy class.\n"
00569 << "void set_initialization_training_algorithm_type(const std::string&) method.\n"
00570 << "Unknown training algorithm type: " << new_training_algorithm_type << ".\n";
00571
00572 throw std::logic_error(buffer.str());
00573 }
00574 }
00575
00576
00577
00578
00582
00583 void TrainingStrategy::set_main_training_algorithm_type(const std::string& new_training_algorithm_type)
00584 {
00585 if(new_training_algorithm_type == "NONE")
00586 {
00587 set_main_training_algorithm_type(NONE);
00588 }
00589 else if(new_training_algorithm_type == "RANDOM_SEARCH")
00590 {
00591 set_main_training_algorithm_type(RANDOM_SEARCH);
00592 }
00593 else if(new_training_algorithm_type == "EVOLUTIONARY_ALGORITHM")
00594 {
00595 set_main_training_algorithm_type(EVOLUTIONARY_ALGORITHM);
00596 }
00597 else if(new_training_algorithm_type == "GRADIENT_DESCENT")
00598 {
00599 set_main_training_algorithm_type(GRADIENT_DESCENT);
00600 }
00601 else if(new_training_algorithm_type == "CONJUGATE_GRADIENT")
00602 {
00603 set_main_training_algorithm_type(CONJUGATE_GRADIENT);
00604 }
00605 else if(new_training_algorithm_type == "QUASI_NEWTON_METHOD")
00606 {
00607 set_main_training_algorithm_type(QUASI_NEWTON_METHOD);
00608 }
00609 else if(new_training_algorithm_type == "LEVENBERG_MARQUARDT_ALGORITHM")
00610 {
00611 set_main_training_algorithm_type(LEVENBERG_MARQUARDT_ALGORITHM);
00612 }
00613 else if(new_training_algorithm_type == "NEWTON_METHOD")
00614 {
00615 set_main_training_algorithm_type(NEWTON_METHOD);
00616 }
00617 else if(new_training_algorithm_type == "USER_TRAINING_ALGORITHM")
00618 {
00619 set_main_training_algorithm_type(USER_TRAINING_ALGORITHM);
00620 }
00621 else
00622 {
00623 std::ostringstream buffer;
00624
00625 buffer << "OpenNN Exception: TrainingStrategy class.\n"
00626 << "void set_main_training_algorithm_type(const std::string&) method.\n"
00627 << "Unknown training algorithm type: " << new_training_algorithm_type << ".\n";
00628
00629 throw std::logic_error(buffer.str());
00630 }
00631 }
00632
00633
00634
00635
00639
00640 void TrainingStrategy::set_refinement_training_algorithm_type(const std::string& new_training_algorithm_type)
00641 {
00642 if(new_training_algorithm_type == "NONE")
00643 {
00644 set_refinement_training_algorithm_type(NONE);
00645 }
00646 else if(new_training_algorithm_type == "RANDOM_SEARCH")
00647 {
00648 set_refinement_training_algorithm_type(RANDOM_SEARCH);
00649 }
00650 else if(new_training_algorithm_type == "EVOLUTIONARY_ALGORITHM")
00651 {
00652 set_refinement_training_algorithm_type(EVOLUTIONARY_ALGORITHM);
00653 }
00654 else if(new_training_algorithm_type == "GRADIENT_DESCENT")
00655 {
00656 set_refinement_training_algorithm_type(GRADIENT_DESCENT);
00657 }
00658 else if(new_training_algorithm_type == "CONJUGATE_GRADIENT")
00659 {
00660 set_refinement_training_algorithm_type(CONJUGATE_GRADIENT);
00661 }
00662 else if(new_training_algorithm_type == "QUASI_NEWTON_METHOD")
00663 {
00664 set_refinement_training_algorithm_type(QUASI_NEWTON_METHOD);
00665 }
00666 else if(new_training_algorithm_type == "LEVENBERG_MARQUARDT_ALGORITHM")
00667 {
00668 set_refinement_training_algorithm_type(LEVENBERG_MARQUARDT_ALGORITHM);
00669 }
00670 else if(new_training_algorithm_type == "NEWTON_METHOD")
00671 {
00672 set_refinement_training_algorithm_type(NEWTON_METHOD);
00673 }
00674 else if(new_training_algorithm_type == "USER_TRAINING_ALGORITHM")
00675 {
00676 set_refinement_training_algorithm_type(USER_TRAINING_ALGORITHM);
00677 }
00678 else
00679 {
00680 std::ostringstream buffer;
00681
00682 buffer << "OpenNN Exception: TrainingStrategy class.\n"
00683 << "void set_refinement_training_algorithm_type(const std::string&) method.\n"
00684 << "Unknown training algorithm type: " << new_training_algorithm_type << ".\n";
00685
00686 throw std::logic_error(buffer.str());
00687 }
00688 }
00689
00690
00691
00692
00695
00696 void TrainingStrategy::set_performance_functional_pointer(PerformanceFunctional* new_performance_functional_pointer)
00697 {
00698 performance_functional_pointer = new_performance_functional_pointer;
00699 }
00700
00701
00702
00703
00706
00707 void TrainingStrategy::set_initialization_training_algorithm_pointer(TrainingAlgorithm* new_initialization_training_algorithm_pointer)
00708 {
00709 initialization_training_algorithm_pointer = new_initialization_training_algorithm_pointer;
00710 }
00711
00712
00713
00714
00717
00718 void TrainingStrategy::set_main_training_algorithm_pointer(TrainingAlgorithm* new_main_training_algorithm_pointer)
00719 {
00720 main_training_algorithm_pointer = new_main_training_algorithm_pointer;
00721 }
00722
00723
00724
00725
00728
00729 void TrainingStrategy::set_refinement_training_algorithm_pointer(TrainingAlgorithm* new_refinement_training_algorithm_pointer)
00730 {
00731 refinement_training_algorithm_pointer = new_refinement_training_algorithm_pointer;
00732 }
00733
00734
00735
00736
00741
00742 void TrainingStrategy::set_display(const bool& new_display)
00743 {
00744 display = new_display;
00745 }
00746
00747
00748
00749
00760
00761 void TrainingStrategy::set_default(void)
00762 {
00763 initialization_training_algorithm_type = TrainingStrategy::NONE;
00764 main_training_algorithm_type = TrainingStrategy::QUASI_NEWTON_METHOD;
00765 refinement_training_algorithm_type = NONE;
00766
00767 initialization_training_algorithm_flag = false;
00768 main_training_algorithm_flag = true;
00769 refinement_training_algorithm_flag = false;
00770
00771 display = true;
00772 }
00773
00774
00775
00776
00781
00782 void TrainingStrategy::construct_initialization_training_algorithm(const TrainingAlgorithmType& new_training_algorithm_type)
00783 {
00784 if(initialization_training_algorithm_pointer)
00785 {
00786 delete initialization_training_algorithm_pointer;
00787 }
00788
00789 initialization_training_algorithm_type = new_training_algorithm_type;
00790 initialization_training_algorithm_flag = true;
00791
00792 switch(initialization_training_algorithm_type)
00793 {
00794 case RANDOM_SEARCH:
00795 {
00796 initialization_training_algorithm_pointer = new RandomSearch(performance_functional_pointer);
00797 }
00798 break;
00799
00800 case EVOLUTIONARY_ALGORITHM:
00801 {
00802 initialization_training_algorithm_pointer = new EvolutionaryAlgorithm(performance_functional_pointer);
00803 }
00804 break;
00805
00806 case GRADIENT_DESCENT:
00807 {
00808 initialization_training_algorithm_pointer = new GradientDescent(performance_functional_pointer);
00809 }
00810 break;
00811
00812 case CONJUGATE_GRADIENT:
00813 {
00814 initialization_training_algorithm_pointer = new ConjugateGradient(performance_functional_pointer);
00815 }
00816 break;
00817
00818 case QUASI_NEWTON_METHOD:
00819 {
00820 initialization_training_algorithm_pointer = new QuasiNewtonMethod(performance_functional_pointer);
00821 }
00822 break;
00823
00824 case LEVENBERG_MARQUARDT_ALGORITHM:
00825 {
00826 initialization_training_algorithm_pointer = new LevenbergMarquardtAlgorithm(performance_functional_pointer);
00827 }
00828 break;
00829
00830 case NEWTON_METHOD:
00831 {
00832 initialization_training_algorithm_pointer = new NewtonMethod(performance_functional_pointer);
00833 }
00834 break;
00835
00836 case USER_TRAINING_ALGORITHM:
00837 {
00838 initialization_training_algorithm_pointer = NULL;
00839 }
00840 break;
00841
00842 default:
00843 {
00844 std::ostringstream buffer;
00845
00846 buffer << "OpenNN Exception: TrainingStrategy class.\n"
00847 << "void construct_initialization_training_algorithm(const TrainingAlgorithmType&) method.\n"
00848 << "Unknown training algorithm type.\n";
00849
00850 throw std::logic_error(buffer.str().c_str());
00851 }
00852 break;
00853 }
00854 }
00855
00856
00857
00858
00863
00864 void TrainingStrategy::construct_main_training_algorithm(const TrainingAlgorithmType& new_training_algorithm_type)
00865 {
00866 if(main_training_algorithm_pointer)
00867 {
00868 delete main_training_algorithm_pointer;
00869 }
00870
00871 main_training_algorithm_type = new_training_algorithm_type;
00872 main_training_algorithm_flag = true;
00873
00874 switch(main_training_algorithm_type)
00875 {
00876 case RANDOM_SEARCH:
00877 {
00878 main_training_algorithm_pointer = new RandomSearch(performance_functional_pointer);
00879 }
00880 break;
00881
00882 case EVOLUTIONARY_ALGORITHM:
00883 {
00884 main_training_algorithm_pointer = new EvolutionaryAlgorithm(performance_functional_pointer);
00885 }
00886 break;
00887
00888 case GRADIENT_DESCENT:
00889 {
00890 main_training_algorithm_pointer = new GradientDescent(performance_functional_pointer);
00891 }
00892 break;
00893
00894 case CONJUGATE_GRADIENT:
00895 {
00896 main_training_algorithm_pointer = new ConjugateGradient(performance_functional_pointer);
00897 }
00898 break;
00899
00900 case QUASI_NEWTON_METHOD:
00901 {
00902 main_training_algorithm_pointer = new QuasiNewtonMethod(performance_functional_pointer);
00903 }
00904 break;
00905
00906 case LEVENBERG_MARQUARDT_ALGORITHM:
00907 {
00908 main_training_algorithm_pointer = new LevenbergMarquardtAlgorithm(performance_functional_pointer);
00909 }
00910 break;
00911
00912 case NEWTON_METHOD:
00913 {
00914 main_training_algorithm_pointer = new NewtonMethod(performance_functional_pointer);
00915 }
00916 break;
00917
00918 case USER_TRAINING_ALGORITHM:
00919 {
00920 main_training_algorithm_pointer = NULL;
00921 }
00922 break;
00923
00924 default:
00925 {
00926 std::ostringstream buffer;
00927
00928 buffer << "OpenNN Exception: TrainingStrategy class.\n"
00929 << "void construct_main_training_algorithm(const TrainingAlgorithmType&) method.\n"
00930 << "Unknown training algorithm type.\n";
00931
00932 throw std::logic_error(buffer.str().c_str());
00933 }
00934 break;
00935 }
00936 }
00937
00938
00939
00940
00945
00946 void TrainingStrategy::construct_refinement_training_algorithm(const TrainingAlgorithmType& new_training_algorithm_type)
00947 {
00948 if(refinement_training_algorithm_pointer)
00949 {
00950 delete refinement_training_algorithm_pointer;
00951 }
00952
00953 refinement_training_algorithm_type = new_training_algorithm_type;
00954 refinement_training_algorithm_flag = true;
00955
00956 switch(main_training_algorithm_type)
00957 {
00958 case RANDOM_SEARCH:
00959 {
00960 refinement_training_algorithm_pointer = new RandomSearch(performance_functional_pointer);
00961 }
00962 break;
00963
00964 case EVOLUTIONARY_ALGORITHM:
00965 {
00966 refinement_training_algorithm_pointer = new EvolutionaryAlgorithm(performance_functional_pointer);
00967 }
00968 break;
00969
00970 case GRADIENT_DESCENT:
00971 {
00972 refinement_training_algorithm_pointer = new GradientDescent(performance_functional_pointer);
00973 }
00974 break;
00975
00976 case CONJUGATE_GRADIENT:
00977 {
00978 refinement_training_algorithm_pointer = new ConjugateGradient(performance_functional_pointer);
00979 }
00980 break;
00981
00982 case QUASI_NEWTON_METHOD:
00983 {
00984 refinement_training_algorithm_pointer = new QuasiNewtonMethod(performance_functional_pointer);
00985 }
00986 break;
00987
00988 case LEVENBERG_MARQUARDT_ALGORITHM:
00989 {
00990 refinement_training_algorithm_pointer = new LevenbergMarquardtAlgorithm(performance_functional_pointer);
00991 }
00992 break;
00993
00994 case NEWTON_METHOD:
00995 {
00996 refinement_training_algorithm_pointer = new NewtonMethod(performance_functional_pointer);
00997 }
00998 break;
00999
01000 case USER_TRAINING_ALGORITHM:
01001 {
01002 refinement_training_algorithm_pointer = NULL;
01003 }
01004 break;
01005
01006 default:
01007 {
01008 std::ostringstream buffer;
01009
01010 buffer << "OpenNN Exception: TrainingStrategy class.\n"
01011 << "void construct_refinement_training_algorithm(const TrainingAlgorithmType&) method.\n"
01012 << "Unknown training algorithm type.\n";
01013
01014 throw std::logic_error(buffer.str().c_str());
01015 }
01016 break;
01017 }
01018 }
01019
01020
01021
01022
01024
01025 void TrainingStrategy::destruct_initialization_training_algorithm(void)
01026 {
01027 delete initialization_training_algorithm_pointer;
01028
01029 initialization_training_algorithm_pointer = NULL;
01030
01031 initialization_training_algorithm_type = NONE;
01032
01033 initialization_training_algorithm_flag = false;
01034 }
01035
01036
01037
01038
01040
01041 void TrainingStrategy::destruct_main_training_algorithm(void)
01042 {
01043 delete main_training_algorithm_pointer;
01044
01045 main_training_algorithm_pointer = NULL;
01046
01047 main_training_algorithm_type = NONE;
01048
01049 main_training_algorithm_flag = false;
01050 }
01051
01052
01053
01054
01056
01057 void TrainingStrategy::destruct_refinement_training_algorithm(void)
01058 {
01059 delete refinement_training_algorithm_pointer;
01060
01061 refinement_training_algorithm_pointer = NULL;
01062
01063 refinement_training_algorithm_type = NONE;
01064
01065 refinement_training_algorithm_flag = false;
01066 }
01067
01068
01069
01070
01075
01076 TrainingStrategy::Results TrainingStrategy::perform_training(void)
01077 {
01078 #ifdef _DEBUG
01079
01080 std::ostringstream buffer;
01081
01082 #endif
01083
01084 Results training_strategy_results;
01085
01086 training_strategy_results.initialization_training_algorithm_results_pointer = NULL;
01087 training_strategy_results.main_training_algorithm_results_pointer = NULL;
01088 training_strategy_results.refinement_training_algorithm_results_pointer = NULL;
01089
01090 if(initialization_training_algorithm_flag)
01091 {
01092 #ifdef _DEBUG
01093
01094 if(!initialization_training_algorithm_pointer)
01095 {
01096 buffer << "OpenNN Exception: TrainingStrategy class.\n"
01097 << "void perform_training(void) method.\n"
01098 << "Pointer to initialization training algorithm is NULL.\n";
01099
01100 throw std::logic_error(buffer.str().c_str());
01101 }
01102
01103 #endif
01104
01105 training_strategy_results.initialization_training_algorithm_results_pointer
01106 = initialization_training_algorithm_pointer->perform_training();
01107 }
01108
01109 if(main_training_algorithm_flag)
01110 {
01111 #ifdef _DEBUG
01112
01113 if(!main_training_algorithm_pointer)
01114 {
01115 buffer << "OpenNN Exception: TrainingStrategy class.\n"
01116 << "void perform_training(void) method.\n"
01117 << "Pointer to main training algorithm is NULL.\n";
01118
01119 throw std::logic_error(buffer.str().c_str());
01120 }
01121
01122 #endif
01123
01124 training_strategy_results.initialization_training_algorithm_results_pointer = main_training_algorithm_pointer->perform_training();
01125 }
01126
01127 if(refinement_training_algorithm_flag)
01128 {
01129 #ifdef _DEBUG
01130
01131 if(!initialization_training_algorithm_pointer)
01132 {
01133 buffer << "OpenNN Exception: TrainingStrategy class.\n"
01134 << "void perform_training(void) method.\n"
01135 << "Pointer to refinement training algorithm is NULL.\n";
01136
01137 throw std::logic_error(buffer.str().c_str());
01138 }
01139
01140 #endif
01141
01142 training_strategy_results.initialization_training_algorithm_results_pointer = refinement_training_algorithm_pointer->perform_training();
01143 }
01144
01145 return(training_strategy_results);
01146 }
01147
01148
01149
01150
01152
01153 std::string TrainingStrategy::to_string(void) const
01154 {
01155 std::ostringstream buffer;
01156
01157 buffer << "Training strategy\n"
01158 << "Initialization training algorithm type: " << write_initialization_training_algorithm_type() << "\n"
01159 << "Main training algorithm type: " << write_main_training_algorithm_type() << "\n"
01160 << "Refinement training algorithm type: " << write_refinement_training_algorithm_type() << "\n"
01161 << "Initialization training algorithm flag: " << initialization_training_algorithm_flag << "\n"
01162 << "Main training algorithm flag: " << main_training_algorithm_flag << "\n"
01163 << "Refinement training algorithm flag: " << refinement_training_algorithm_flag << "\n";
01164
01165 if(initialization_training_algorithm_pointer)
01166 {
01167 buffer << "Initialization training algorithm:\n"
01168 << initialization_training_algorithm_pointer->to_string();
01169 }
01170
01171 if(main_training_algorithm_pointer)
01172 {
01173 buffer << "Main training algorithm:\n"
01174 << main_training_algorithm_pointer->to_string();
01175 }
01176
01177 if(refinement_training_algorithm_pointer)
01178 {
01179 buffer << "Refinement training algorithm:\n"
01180 << refinement_training_algorithm_pointer->to_string();
01181 }
01182
01183 return(buffer.str());
01184 }
01185
01186
01187
01188
01190
01191 void TrainingStrategy::print(void) const
01192 {
01193 std::cout << to_string();
01194 }
01195
01196
01197
01198
01201
01202 TiXmlElement* TrainingStrategy::to_XML(void) const
01203 {
01204 std::ostringstream buffer;
01205
01206
01207
01208 TiXmlElement* training_strategy_element = new TiXmlElement("TrainingStrategy");
01209 training_strategy_element->SetAttribute("Version", 4);
01210
01211
01212 {
01213 TiXmlElement* initialization_training_algorithm_type_element = new TiXmlElement("InitializationTrainingAlgorithmType");
01214 training_strategy_element->LinkEndChild(initialization_training_algorithm_type_element);
01215
01216 const std::string new_initialization_training_algorithm_type = write_initialization_training_algorithm_type();
01217
01218 TiXmlText* initialization_training_algorithm_type_text = new TiXmlText(new_initialization_training_algorithm_type.c_str());
01219 initialization_training_algorithm_type_element->LinkEndChild(initialization_training_algorithm_type_text);
01220 }
01221
01222
01223 {
01224 TiXmlElement* main_training_algorithm_type_element = new TiXmlElement("MainTrainingAlgorithmType");
01225 training_strategy_element->LinkEndChild(main_training_algorithm_type_element);
01226
01227 const std::string new_main_training_algorithm_type = write_main_training_algorithm_type();
01228
01229 TiXmlText* main_training_algorithm_type_text = new TiXmlText(new_main_training_algorithm_type.c_str());
01230 main_training_algorithm_type_element->LinkEndChild(main_training_algorithm_type_text);
01231 }
01232
01233
01234 {
01235 TiXmlElement* refinement_training_algorithm_type_element = new TiXmlElement("RefinementTrainingAlgorithmType");
01236 training_strategy_element->LinkEndChild(refinement_training_algorithm_type_element);
01237
01238 const std::string new_refinement_training_algorithm_type = write_refinement_training_algorithm_type();
01239
01240 TiXmlText* refinement_training_algorithm_type_text = new TiXmlText(new_refinement_training_algorithm_type.c_str());
01241 refinement_training_algorithm_type_element->LinkEndChild(refinement_training_algorithm_type_text);
01242 }
01243
01244
01245 {
01246 TiXmlElement* initialization_training_algorithm_flag_element = new TiXmlElement("InitializationTrainingAlgorithmFlag");
01247 training_strategy_element->LinkEndChild(initialization_training_algorithm_flag_element);
01248
01249 buffer.str("");
01250 buffer << initialization_training_algorithm_flag;
01251
01252 TiXmlText* initialization_training_algorithm_flag_text = new TiXmlText(buffer.str().c_str());
01253 initialization_training_algorithm_flag_element->LinkEndChild(initialization_training_algorithm_flag_text);
01254 }
01255
01256
01257 {
01258 TiXmlElement* main_training_algorithm_flag_element = new TiXmlElement("MainTrainingAlgorithmFlag");
01259 training_strategy_element->LinkEndChild(main_training_algorithm_flag_element);
01260
01261 buffer.str("");
01262 buffer << main_training_algorithm_flag;
01263
01264 TiXmlText* main_training_algorithm_flag_text = new TiXmlText(buffer.str().c_str());
01265 main_training_algorithm_flag_element->LinkEndChild(main_training_algorithm_flag_text);
01266 }
01267
01268
01269 {
01270 TiXmlElement* refinement_training_algorithm_flag_element = new TiXmlElement("RefinementTrainingAlgorithmFlag");
01271 training_strategy_element->LinkEndChild(refinement_training_algorithm_flag_element);
01272
01273 buffer.str("");
01274 buffer << refinement_training_algorithm_flag;
01275
01276 TiXmlText* refinement_training_algorithm_flag_text = new TiXmlText(buffer.str().c_str());
01277 refinement_training_algorithm_flag_element->LinkEndChild(refinement_training_algorithm_flag_text);
01278 }
01279
01280
01281
01282 if(initialization_training_algorithm_pointer)
01283 {
01284 TiXmlElement* initialization_training_algorithm_element = initialization_training_algorithm_pointer->to_XML();
01285
01286 training_strategy_element->LinkEndChild(initialization_training_algorithm_element);
01287 }
01288
01289
01290
01291 if(main_training_algorithm_pointer)
01292 {
01293 TiXmlElement* main_training_algorithm_element = main_training_algorithm_pointer->to_XML();
01294
01295 training_strategy_element->LinkEndChild(main_training_algorithm_element);
01296 }
01297
01298
01299
01300 if(refinement_training_algorithm_pointer)
01301 {
01302 TiXmlElement* refinement_training_algorithm_element = refinement_training_algorithm_pointer->to_XML();
01303
01304 training_strategy_element->LinkEndChild(refinement_training_algorithm_element);
01305 }
01306
01307
01308 {
01309 TiXmlElement* display_element = new TiXmlElement("Display");
01310 training_strategy_element->LinkEndChild(display_element);
01311
01312 buffer.str("");
01313 buffer << display;
01314
01315 TiXmlText* display_text = new TiXmlText(buffer.str().c_str());
01316 display_element->LinkEndChild(display_text);
01317 }
01318
01319 return(training_strategy_element);
01320 }
01321
01322
01323
01324
01327
01328 void TrainingStrategy::from_XML(TiXmlElement* training_strategy_element)
01329 {
01330 if(!training_strategy_element)
01331 {
01332 return;
01333 }
01334
01335
01336 {
01337 TiXmlElement* initialization_training_algorithm_type_element = training_strategy_element->FirstChildElement("InitializationTrainingAlgorithmType");
01338
01339 if(initialization_training_algorithm_type_element)
01340 {
01341 const std::string new_initialization_training_algorithm_type = initialization_training_algorithm_type_element->GetText();
01342
01343 try
01344 {
01345 set_initialization_training_algorithm_type(new_initialization_training_algorithm_type);
01346 }
01347 catch(std::exception& e)
01348 {
01349 std::cout << e.what() << std::endl;
01350 }
01351 }
01352 }
01353
01354
01355 {
01356 TiXmlElement* main_training_algorithm_type_element = training_strategy_element->FirstChildElement("MainTrainingAlgorithmType");
01357
01358 if(main_training_algorithm_type_element)
01359 {
01360 const std::string new_main_training_algorithm_type = main_training_algorithm_type_element->GetText();
01361
01362 try
01363 {
01364 set_main_training_algorithm_type(new_main_training_algorithm_type);
01365 }
01366 catch(std::exception& e)
01367 {
01368 std::cout << e.what() << std::endl;
01369 }
01370 }
01371 }
01372
01373
01374 {
01375 TiXmlElement* refinement_training_algorithm_type_element = training_strategy_element->FirstChildElement("RefinementTrainingAlgorithmType");
01376
01377 if(refinement_training_algorithm_type_element)
01378 {
01379 const std::string new_refinement_training_algorithm_type = refinement_training_algorithm_type_element->GetText();
01380
01381 try
01382 {
01383 set_refinement_training_algorithm_type(new_refinement_training_algorithm_type);
01384 }
01385 catch(std::exception& e)
01386 {
01387 std::cout << e.what() << std::endl;
01388 }
01389 }
01390 }
01391
01392
01393 {
01394 TiXmlElement* initialization_training_algorithm_flag_element = training_strategy_element->FirstChildElement("InitializationTrainingAlgorithmFlag");
01395
01396 if(initialization_training_algorithm_flag_element)
01397 {
01398 const std::string new_initialization_training_algorithm_flag = initialization_training_algorithm_flag_element->GetText();
01399
01400 try
01401 {
01402 set_initialization_training_algorithm_flag(new_initialization_training_algorithm_flag != "0");
01403 }
01404 catch(std::exception& e)
01405 {
01406 std::cout << e.what() << std::endl;
01407 }
01408 }
01409 }
01410
01411
01412 {
01413 TiXmlElement* main_training_algorithm_flag_element = training_strategy_element->FirstChildElement("MainTrainingAlgorithmFlag");
01414
01415 if(main_training_algorithm_flag_element)
01416 {
01417 const std::string new_main_training_algorithm_flag = main_training_algorithm_flag_element->GetText();
01418
01419 try
01420 {
01421 set_main_training_algorithm_flag(new_main_training_algorithm_flag != "0");
01422 }
01423 catch(std::exception& e)
01424 {
01425 std::cout << e.what() << std::endl;
01426 }
01427 }
01428 }
01429
01430
01431 {
01432 TiXmlElement* refinement_training_algorithm_flag_element = training_strategy_element->FirstChildElement("RefinementTrainingAlgorithmFlag");
01433
01434 if(refinement_training_algorithm_flag_element)
01435 {
01436 const std::string new_refinement_training_algorithm_flag = refinement_training_algorithm_flag_element->GetText();
01437
01438 try
01439 {
01440 set_refinement_training_algorithm_flag(new_refinement_training_algorithm_flag != "0");
01441 }
01442 catch(std::exception& e)
01443 {
01444 std::cout << e.what() << std::endl;
01445 }
01446 }
01447 }
01448
01449
01450
01451 if(initialization_training_algorithm_pointer)
01452 {
01453 TiXmlElement* element = initialization_training_algorithm_pointer->to_XML();
01454
01455 training_strategy_element->LinkEndChild(element);
01456 }
01457
01458
01459
01460 if(main_training_algorithm_pointer)
01461 {
01462 TiXmlElement* element = main_training_algorithm_pointer->to_XML();
01463
01464 training_strategy_element->LinkEndChild(element);
01465 }
01466
01467
01468
01469 if(refinement_training_algorithm_pointer)
01470 {
01471 TiXmlElement* element = refinement_training_algorithm_pointer->to_XML();
01472
01473 training_strategy_element->LinkEndChild(element);
01474 }
01475
01476
01477 {
01478 TiXmlElement* display_element = training_strategy_element->FirstChildElement("Display");
01479
01480 if(display_element)
01481 {
01482 std::string new_display = display_element->GetText();
01483
01484 try
01485 {
01486 set_display(new_display != "0");
01487 }
01488 catch(std::exception& e)
01489 {
01490 std::cout << e.what() << std::endl;
01491 }
01492 }
01493 }
01494 }
01495
01496
01497
01498
01501
01502 void TrainingStrategy::save(const std::string& filename) const
01503 {
01504 std::ostringstream buffer;
01505
01506 TiXmlDocument document;
01507
01508
01509
01510 TiXmlDeclaration* declaration = new TiXmlDeclaration("1.0", "", "");
01511 document.LinkEndChild(declaration);
01512
01513
01514
01515 TiXmlElement* training_strategy_element = to_XML();
01516 document.LinkEndChild(training_strategy_element);
01517
01518 document.SaveFile(filename.c_str());
01519 }
01520
01521
01522
01523
01527
01528 void TrainingStrategy::load(const std::string& filename)
01529 {
01530 set_default();
01531
01532 std::ostringstream buffer;
01533
01534 TiXmlDocument document(filename.c_str());
01535
01536 if (!document.LoadFile())
01537 {
01538 buffer << "OpenNN Exception: TrainingStrategy class.\n"
01539 << "void load(const std::string&) method.\n"
01540 << "Cannot load XML file " << filename << ".\n";
01541
01542 throw std::logic_error(buffer.str());
01543 }
01544
01545
01546
01547 TiXmlElement* training_algorithm_element = document.FirstChildElement("TrainingStrategy");
01548
01549 if(!training_algorithm_element)
01550 {
01551 buffer << "OpenNN Exception: TrainingStrategy class.\n"
01552 << "void load(const std::string&) method.\n"
01553 << "File " << filename << " is not a valid training algorithm file.\n";
01554
01555 throw std::logic_error(buffer.str());
01556 }
01557 }
01558
01559
01560
01561
01564
01565 void TrainingStrategy::Results::save(const std::string& filename) const
01566 {
01567 std::ofstream file(filename.c_str());
01568
01569 if(initialization_training_algorithm_results_pointer)
01570 {
01571 file << initialization_training_algorithm_results_pointer->to_string();
01572 }
01573
01574 if(main_training_algorithm_results_pointer)
01575 {
01576 file << main_training_algorithm_results_pointer->to_string();
01577 }
01578
01579 if(refinement_training_algorithm_results_pointer)
01580 {
01581 file << refinement_training_algorithm_results_pointer->to_string();
01582 }
01583
01584 file.close();
01585 }
01586
01587 }
01588
01589
01590
01591
01592
01593
01594
01595
01596
01597
01598
01599
01600
01601
01602
01603
01604
01605