00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #include <iostream>
00019 #include <fstream>
00020 #include <algorithm>
00021 #include <functional>
00022 #include <limits>
00023 #include <cmath>
00024 #include <ctime>
00025
00026
00027
00028 #include "training_rate_algorithm.h"
00029
00030
00031
00032 #include "../../parsers/tinyxml/tinyxml.h"
00033
00034 namespace OpenNN
00035 {
00036
00037
00038
00042
00043 TrainingRateAlgorithm::TrainingRateAlgorithm(void)
00044 : performance_functional_pointer(NULL)
00045 {
00046 set_default();
00047 }
00048
00049
00050
00051
00056
00057 TrainingRateAlgorithm::TrainingRateAlgorithm(PerformanceFunctional* new_performance_functional_pointer)
00058 : performance_functional_pointer(new_performance_functional_pointer)
00059 {
00060 set_default();
00061 }
00062
00063
00064
00065
00070
00071 TrainingRateAlgorithm::TrainingRateAlgorithm(TiXmlElement* training_rate_algorithm_element)
00072 : performance_functional_pointer(NULL)
00073 {
00074 from_XML(training_rate_algorithm_element);
00075 }
00076
00077
00078
00079
00081
00082 TrainingRateAlgorithm::~TrainingRateAlgorithm(void)
00083 {
00084 }
00085
00086
00087
00088
00089
00090
00093
00094 PerformanceFunctional* TrainingRateAlgorithm::get_performance_functional_pointer(void)
00095 {
00096 return(performance_functional_pointer);
00097 }
00098
00099
00100
00101
00103
00104 const TrainingRateAlgorithm::TrainingRateMethod& TrainingRateAlgorithm::get_training_rate_method(void) const
00105 {
00106 return(training_rate_method);
00107 }
00108
00109
00110
00111
00113
00114 std::string TrainingRateAlgorithm::write_training_rate_method(void) const
00115 {
00116 switch(training_rate_method)
00117 {
00118 case Fixed:
00119 {
00120 return("Fixed");
00121 }
00122 break;
00123
00124 case GoldenSection:
00125 {
00126 return("GoldenSection");
00127 }
00128 break;
00129
00130 case BrentMethod:
00131 {
00132 return("BrentMethod");
00133 }
00134 break;
00135
00136 default:
00137 {
00138 std::ostringstream buffer;
00139
00140 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00141 << "std::string get_training_rate_method(void) const method.\n"
00142 << "Unknown training rate method.\n";
00143
00144 throw std::logic_error(buffer.str().c_str());
00145 }
00146 break;
00147 }
00148 }
00149
00150
00151
00152
00154
00155 const double& TrainingRateAlgorithm::get_first_training_rate(void) const
00156 {
00157 return(first_training_rate);
00158 }
00159
00160
00161
00162
00164
00165 const double& TrainingRateAlgorithm::get_bracketing_factor(void) const
00166 {
00167 return(bracketing_factor);
00168 }
00169
00170
00171
00172
00174
00175 const double& TrainingRateAlgorithm::get_training_rate_tolerance(void) const
00176 {
00177 return(training_rate_tolerance);
00178 }
00179
00180
00181
00182
00185
00186 const double& TrainingRateAlgorithm::get_warning_training_rate(void) const
00187 {
00188 return(warning_training_rate);
00189 }
00190
00191
00192
00193
00196
00197 const double& TrainingRateAlgorithm::get_error_training_rate(void) const
00198 {
00199 return(error_training_rate);
00200 }
00201
00202
00203
00204
00207
00208 const bool& TrainingRateAlgorithm::get_display(void) const
00209 {
00210 return(display);
00211 }
00212
00213
00214
00215
00218
00219 void TrainingRateAlgorithm::set(void)
00220 {
00221 performance_functional_pointer = NULL;
00222 set_default();
00223 }
00224
00225
00226
00227
00231
00232 void TrainingRateAlgorithm::set(PerformanceFunctional* new_performance_functional_pointer)
00233 {
00234 performance_functional_pointer = new_performance_functional_pointer;
00235 set_default();
00236 }
00237
00238
00239
00240
00242
00243 void TrainingRateAlgorithm::set_default(void)
00244 {
00245
00246
00247 training_rate_method = BrentMethod;
00248
00249
00250
00251 bracketing_factor = 1.5;
00252
00253 first_training_rate = 1.0e-2;
00254 training_rate_tolerance = 1.0e-6;
00255
00256 warning_training_rate = 1.0e6;
00257
00258 error_training_rate = 1.0e9;
00259
00260
00261
00262 display = true;
00263 }
00264
00265
00266
00267
00270
00271 void TrainingRateAlgorithm::set_performance_functional_pointer(PerformanceFunctional* new_performance_functional_pointer)
00272 {
00273 performance_functional_pointer = new_performance_functional_pointer;
00274 }
00275
00276
00277
00278
00281
00282 void TrainingRateAlgorithm::set_training_rate_method(const TrainingRateAlgorithm::TrainingRateMethod& new_training_rate_method)
00283 {
00284 training_rate_method = new_training_rate_method;
00285 }
00286
00287
00288
00289
00292
00293 void TrainingRateAlgorithm::set_training_rate_method(const std::string& new_training_rate_method)
00294 {
00295 if(new_training_rate_method == "Fixed")
00296 {
00297 training_rate_method = Fixed;
00298 }
00299 else if(new_training_rate_method == "GoldenSection")
00300 {
00301 training_rate_method = GoldenSection;
00302 }
00303 else if(new_training_rate_method == "BrentMethod")
00304 {
00305 training_rate_method = BrentMethod;
00306 }
00307 else
00308 {
00309 std::ostringstream buffer;
00310
00311 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00312 << "void set_method(const std::string&) method.\n"
00313 << "Unknown training rate method: " << new_training_rate_method << ".\n";
00314
00315 throw std::logic_error(buffer.str().c_str());
00316 }
00317 }
00318
00319
00320
00321
00324
00325 void TrainingRateAlgorithm::set_first_training_rate(const double& new_first_training_rate)
00326 {
00327
00328
00329 #ifdef _DEBUG
00330
00331 if(new_first_training_rate < 0.0)
00332 {
00333 std::ostringstream buffer;
00334
00335 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00336 << "void set_first_training_rate(const double&) method.\n"
00337 << "First training rate must be equal or greater than 0.\n";
00338
00339 throw std::logic_error(buffer.str().c_str());
00340 }
00341
00342 #endif
00343
00344
00345
00346 first_training_rate = new_first_training_rate;
00347 }
00348
00349
00350
00351
00354
00355 void TrainingRateAlgorithm::set_bracketing_factor(const double& new_bracketing_factor)
00356 {
00357
00358
00359 #ifdef _DEBUG
00360
00361 if(new_bracketing_factor < 0.0)
00362 {
00363 std::ostringstream buffer;
00364
00365 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00366 << "void set_bracketing_factor(const double&) method.\n"
00367 << "Bracketing factor must be equal or greater than 0.\n";
00368
00369 throw std::logic_error(buffer.str().c_str());
00370 }
00371
00372 #endif
00373
00374 bracketing_factor = new_bracketing_factor;
00375 }
00376
00377
00378
00379
00382
00383 void TrainingRateAlgorithm::set_training_rate_tolerance(const double& new_training_rate_tolerance)
00384 {
00385
00386
00387 #ifdef _DEBUG
00388
00389 if(new_training_rate_tolerance < 0.0)
00390 {
00391 std::ostringstream buffer;
00392
00393 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00394 << "void set_training_rate_tolerance(const double&) method.\n"
00395 << "Tolerance must be equal or greater than 0.\n";
00396
00397 throw std::logic_error(buffer.str().c_str());
00398 }
00399
00400 #endif
00401
00402
00403
00404 training_rate_tolerance = new_training_rate_tolerance;
00405 }
00406
00407
00408
00409
00413
00414 void TrainingRateAlgorithm::set_warning_training_rate(const double& new_warning_training_rate)
00415 {
00416
00417
00418 #ifdef _DEBUG
00419
00420 if(new_warning_training_rate < 0.0)
00421 {
00422 std::ostringstream buffer;
00423
00424 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00425 << "void set_warning_training_rate(const double&) method.\n"
00426 << "Warning training rate must be equal or greater than 0.\n";
00427
00428 throw std::logic_error(buffer.str().c_str());
00429 }
00430
00431 #endif
00432
00433 warning_training_rate = new_warning_training_rate;
00434 }
00435
00436
00437
00438
00442
00443 void TrainingRateAlgorithm::set_error_training_rate(const double& new_error_training_rate)
00444 {
00445
00446
00447 #ifdef _DEBUG
00448
00449 if(new_error_training_rate < 0.0)
00450 {
00451 std::ostringstream buffer;
00452
00453 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00454 << "void set_error_training_rate(const double&) method.\n"
00455 << "Error training rate must be equal or greater than 0.\n";
00456
00457 throw std::logic_error(buffer.str().c_str());
00458 }
00459
00460 #endif
00461
00462
00463
00464 error_training_rate = new_error_training_rate;
00465 }
00466
00467
00468
00469
00474
00475 void TrainingRateAlgorithm::set_display(const bool& new_display)
00476 {
00477 display = new_display;
00478 }
00479
00480
00481
00482
00488
00489 Vector<double> TrainingRateAlgorithm::calculate_directional_point(const double& performance, const Vector<double>& training_direction, const double& initial_training_rate) const
00490 {
00491 #ifdef _DEBUG
00492
00493 if(performance_functional_pointer == NULL)
00494 {
00495 std::ostringstream buffer;
00496
00497 buffer << "OpenNN Error: TrainingRateAlgorithm class.\n"
00498 << "Vector<double> calculate_directional_point(const double&, const Vector<double>&, const double&) const method.\n"
00499 << "Pointer to performance functional is NULL.\n";
00500
00501 throw std::logic_error(buffer.str().c_str());
00502 }
00503
00504 #endif
00505
00506 #ifdef _DEBUG
00507
00508 NeuralNetwork* neural_network_pointer = performance_functional_pointer->get_neural_network_pointer();
00509
00510 if(neural_network_pointer == NULL)
00511 {
00512 std::ostringstream buffer;
00513
00514 buffer << "OpenNN Error: TrainingRateAlgorithm class.\n"
00515 << "Vector<double> calculate_directional_point(const double&, const Vector<double>&, const double&) const method.\n"
00516 << "Pointer to neural network is NULL.\n";
00517
00518 throw std::logic_error(buffer.str().c_str());
00519 }
00520
00521 #endif
00522
00523
00524 switch(training_rate_method)
00525 {
00526 case TrainingRateAlgorithm::Fixed:
00527 {
00528 return(calculate_fixed_directional_point(performance, training_direction, initial_training_rate));
00529 }
00530 break;
00531
00532 case TrainingRateAlgorithm::GoldenSection:
00533 {
00534 return(calculate_golden_section_directional_point(performance, training_direction, initial_training_rate));
00535 }
00536 break;
00537
00538 case TrainingRateAlgorithm::BrentMethod:
00539 {
00540 return(calculate_Brent_method_directional_point(performance, training_direction, initial_training_rate));
00541 }
00542 break;
00543
00544 default:
00545 {
00546 std::ostringstream buffer;
00547
00548 buffer << "OpenNN Exception: TrainingRateAlgorithm class\n"
00549 << "Vector<double> calculate_directional_point(double, const Vector<double>&, double) const method.\n"
00550 << "Unknown training rate method.\n";
00551
00552 throw std::logic_error(buffer.str().c_str());
00553 }
00554 }
00555 }
00556
00557
00558
00559
00566
00567 Vector< Vector<double> > TrainingRateAlgorithm::calculate_bracketing_training_rate(const double& performance, const Vector<double>& training_direction, const double& initial_training_rate) const
00568 {
00569
00570
00571 Vector<double> A(2);
00572 A[0] = 0;
00573 A[1] = performance;
00574
00575 Vector<double> U(2);
00576 U[0] = initial_training_rate;
00577 U[1] = performance_functional_pointer->calculate_directional_performance(training_direction, U[0]);
00578
00579 Vector<double> B = U;
00580
00581 while(A[1] <= U[1])
00582 {
00583 B = U;
00584
00585 U[0] /= bracketing_factor;
00586 U[1] = performance_functional_pointer->calculate_directional_performance(training_direction, U[0]);
00587
00588 if(U[0] < training_rate_tolerance)
00589 {
00590
00591
00592
00593
00594
00595
00596
00597
00598 Vector< Vector<double> > bracketing_training_rate(3, A);
00599
00600 return(bracketing_training_rate);
00601 }
00602 }
00603
00604 while(U[1] >= B[1])
00605 {
00606 B[0] *= bracketing_factor;
00607 B[1] = performance_functional_pointer->calculate_directional_performance(training_direction, B[0]);
00608
00609
00610 if(B[0] > error_training_rate)
00611 {
00612 std::ostringstream buffer;
00613
00614 buffer << "OpenNN Warning: TrainingRateAlgorithm class.\n"
00615 << "Vector<double> calculate_bracketing_training_rate(double, const Vector<double>&, double) const method\n."
00616 << "Right point is " << B[0] << "\n";
00617
00618 throw std::logic_error(buffer.str().c_str());
00619 }
00620 }
00621
00622 if((A[0] >= U[0] || U[0] >= B[0]) || (A[1] <= U[1] || U[1] >= B[1]))
00623 {
00624 std::ostringstream buffer;
00625
00626 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00627 << "Vector<double> calculate_bracketing_training_rate(double, const Vector<double>&, double) const method\n."
00628 << "Uncorrect triplet:\n"
00629 << "A = (" << A[0] << "," << A[1] << ")\n"
00630 << "U = (" << U[0] << "," << U[1] << ")\n"
00631 << "B = (" << B[0] << "," << B[1] << ")\n";
00632
00633 throw std::logic_error(buffer.str());
00634 }
00635
00636
00637
00638 Vector< Vector<double> > bracketing_training_rate(3);
00639 bracketing_training_rate[0] = A;
00640 bracketing_training_rate[1] = U;
00641 bracketing_training_rate[2] = B;
00642
00643 return(bracketing_training_rate);
00644 }
00645
00646
00647
00648
00652
00653 Vector<double> TrainingRateAlgorithm::calculate_fixed_directional_point(const double&, const Vector<double>& training_direction, const double&) const
00654 {
00655 Vector<double> directional_point(2);
00656
00657 directional_point[0] = first_training_rate;
00658 directional_point[1] = performance_functional_pointer->calculate_directional_performance(training_direction, first_training_rate);
00659
00660 return(directional_point);
00661 }
00662
00663
00664
00665
00671
00672 Vector<double> TrainingRateAlgorithm::calculate_golden_section_directional_point
00673 (const double& performance, const Vector<double>& training_direction, const double& initial_training_rate) const
00674 {
00675 std::ostringstream buffer;
00676
00677
00678
00679 try
00680 {
00681 Vector< Vector<double> > bracketing_training_rate = calculate_bracketing_training_rate(performance, training_direction, initial_training_rate);
00682
00683 Vector<double> A = bracketing_training_rate[0];
00684 Vector<double> U = bracketing_training_rate[1];
00685 Vector<double> B = bracketing_training_rate[2];
00686
00687 if(A == B)
00688 {
00689 return(A);
00690 }
00691
00692 Vector<double> V(2);
00693
00694
00695
00696 do
00697 {
00698 V[0] = calculate_golden_section_training_rate(A, U, B);
00699 V[1] = performance_functional_pointer->calculate_directional_performance(training_direction, V[0]);
00700
00701
00702
00703 if(V[0] < U[0] && V[1] >= U[1])
00704 {
00705 A = V;
00706
00707
00708 }
00709 else if(V[0] < U[0] && V[1] <= U[1])
00710 {
00711
00712 B = U;
00713 U = V;
00714 }
00715 else if(V[0] > U[0] && V[1] >= U[1])
00716 {
00717
00718 B = V;
00719
00720 }
00721 else if(V[0] > U[0] && V[1] <= U[1])
00722 {
00723 A = U;
00724
00725 U = V;
00726 }
00727 else if(V[0] == U[0])
00728 {
00729 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00730 << "Vector<double> calculate_golden_section_directional_point(double, const Vector<double>, double) const method.\n"
00731 << "Both interior points have the same ordinate.\n";
00732
00733 std::cout << buffer.str() << std::endl;
00734
00735 break;
00736 }
00737 else
00738 {
00739 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00740 << "Vector<double> calculate_golden_section_directional_point(double, const Vector<double>, double) const method.\n"
00741 << "Unknown set:\n"
00742 << "A = (" << A[0] << "," << A[1] << ")\n"
00743 << "B = (" << B[0] << "," << B[1] << ")\n"
00744 << "U = (" << U[0] << "," << U[1] << ")\n"
00745 << "V = (" << V[0] << "," << V[1] << ")\n";
00746
00747 throw std::logic_error(buffer.str());
00748 }
00749
00750
00751
00752 if(A[1] < U[1] || U[1] > B[1])
00753 {
00754 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00755 << "Vector<double> calculate_golden_section_directional_point(double, const Vector<double>, double) const method.\n"
00756 << "Triplet does not satisfy minimum condition:\n"
00757 << "A = (" << A[0] << "," << A[1] << ")\n"
00758 << "B = (" << B[0] << "," << B[1] << ")\n"
00759 << "U = (" << U[0] << "," << U[1] << ")\n";
00760
00761 throw std::logic_error(buffer.str());
00762 }
00763
00764 }while(B[0] - A[0] > training_rate_tolerance);
00765
00766 return(U);
00767 }
00768 catch(std::range_error& e)
00769 {
00770 std::cerr << e.what() << std::endl;
00771
00772 Vector<double> A(2);
00773 A[0] = 0.0;
00774 A[1] = performance;
00775
00776 return(A);
00777 }
00778 catch(std::logic_error& e)
00779 {
00780 std::cerr << e.what() << std::endl;
00781
00782 Vector<double> X(2);
00783 X[0] = first_training_rate;
00784 X[1] = performance_functional_pointer->calculate_directional_performance(training_direction, X[0]);
00785
00786 if(X[1] > performance)
00787 {
00788 X[0] = 0.0;
00789 X[1] = 0.0;
00790 }
00791
00792 return(X);
00793 }
00794 }
00795
00796
00797
00798
00804
00805 Vector<double> TrainingRateAlgorithm::calculate_Brent_method_directional_point
00806 (const double& performance, const Vector<double>& training_direction, const double& initial_training_rate) const
00807 {
00808 std::ostringstream buffer;
00809
00810
00811
00812 try
00813 {
00814 Vector< Vector<double> > bracketing_training_rate = calculate_bracketing_training_rate(performance, training_direction, initial_training_rate);
00815
00816 Vector<double> A = bracketing_training_rate[0];
00817 Vector<double> U = bracketing_training_rate[1];
00818 Vector<double> B = bracketing_training_rate[2];
00819
00820 if(A == B)
00821 {
00822 return(A);
00823 }
00824
00825 Vector<double> V(2);
00826
00827
00828
00829 while(B[0] - A[0] > training_rate_tolerance)
00830 {
00831 try
00832 {
00833 V[0] = calculate_Brent_method_training_rate(A, U, B);
00834 }
00835 catch(std::logic_error&)
00836 {
00837 V[0] = calculate_golden_section_training_rate(A, U, B);
00838 }
00839
00840
00841
00842 V[1] = performance_functional_pointer->calculate_directional_performance(training_direction, V[0]);
00843
00844
00845
00846 if(V[0] < U[0] && V[1] >= U[1])
00847 {
00848 A = V;
00849
00850
00851 }
00852 else if(V[0] < U[0] && V[1] <= U[1])
00853 {
00854
00855 B = U;
00856 U = V;
00857 }
00858 else if(V[0] > U[0] && V[1] >= U[1])
00859 {
00860
00861 B = V;
00862
00863 }
00864 else if(V[0] > U[0] && V[1] <= U[1])
00865 {
00866 A = U;
00867
00868 U = V;
00869 }
00870 else if(V[0] == U[0])
00871 {
00872 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00873 << "Vector<double> calculate_Brent_method_directional_point(double, const Vector<double>, double) const method.\n"
00874 << "Both interior points have the same ordinate.\n";
00875
00876 break;
00877 }
00878 else
00879 {
00880 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00881 << "Vector<double> calculate_Brent_method_directional_point(double, const Vector<double>, double) const method.\n"
00882 << "Unknown set:\n"
00883 << "A = (" << A[0] << "," << A[1] << ")\n"
00884 << "B = (" << B[0] << "," << B[1] << ")\n"
00885 << "U = (" << U[0] << "," << U[1] << ")\n"
00886 << "V = (" << V[0] << "," << V[1] << ")\n";
00887
00888 throw std::logic_error(buffer.str());
00889 }
00890
00891
00892
00893 if(A[1] < U[1] || U[1] > B[1])
00894 {
00895 std::ostringstream buffer;
00896
00897 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00898 << "Vector<double> calculate_Brent_method_directional_point(double, const Vector<double>, double) const method.\n"
00899 << "Triplet does not satisfy minimum condition:\n"
00900 << "A = (" << A[0] << "," << A[1] << ")\n"
00901 << "B = (" << B[0] << "," << B[1] << ")\n"
00902 << "U = (" << U[0] << "," << U[1] << ")\n";
00903
00904 throw std::logic_error(buffer.str());
00905 }
00906 }
00907
00908 return(U);
00909 }
00910 catch(std::range_error& e)
00911 {
00912 std::cerr << e.what() << std::endl;
00913
00914 Vector<double> A(2);
00915 A[0] = 0.0;
00916 A[1] = performance;
00917
00918 return(A);
00919 }
00920 catch(std::logic_error& e)
00921 {
00922 std::cerr << e.what() << std::endl;
00923
00924 Vector<double> X(2);
00925 X[0] = first_training_rate;
00926 X[1] = performance_functional_pointer->calculate_directional_performance(training_direction, X[0]);
00927
00928 if(X[1] > performance)
00929 {
00930 X[0] = 0.0;
00931 X[1] = 0.0;
00932 }
00933
00934 return(X);
00935 }
00936 }
00937
00938
00939
00940
00945
00946 double TrainingRateAlgorithm::calculate_golden_section_training_rate(const Vector<double>& A, const Vector<double>& U, const Vector<double>& B) const
00947 {
00948
00949
00950 if(U[0] < A[0] + 0.5*(B[0] - A[0]))
00951 {
00952 return(A[0] + 0.618*(B[0] - A[0]));
00953 }
00954 else
00955 {
00956 return(A[0] + 0.382*(B[0] - A[0]));
00957 }
00958 }
00959
00960
00961
00962
00967
00968 double TrainingRateAlgorithm::calculate_Brent_method_training_rate(const Vector<double>& A, const Vector<double>& B, const Vector<double>& C) const
00969 {
00970 std::ostringstream buffer;
00971
00972 const double c = -(A[1]*(B[0]-C[0]) + B[1]*(C[0]-A[0]) + C[1]*(A[0]-B[0]))/((A[0]-B[0])*(B[0]-C[0])*(C[0]-A[0]));
00973
00974 if(c == 0)
00975 {
00976 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00977 << "double calculate_Brent_method_training_rate(Vector<double>&, Vector<double>&, Vector<double>&) const method.\n"
00978 << "Parabola cannot be constructed.\n";
00979
00980 throw std::logic_error(buffer.str());
00981 }
00982 else if(c < 0)
00983 {
00984 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00985 << "double calculate_Brent_method_training_rate(Vector<double>&, Vector<double>&, Vector<double>&) const method.\n"
00986 << "Parabola does not have a minimum but a maximum.\n";
00987
00988 throw std::logic_error(buffer.str());
00989 }
00990
00991 const double b = (A[1]*(B[0]*B[0]-C[0]*C[0]) + B[1]*(C[0]*C[0]-A[0]*A[0]) + C[1]*(A[0]*A[0]-B[0]*B[0]))/((A[0]-B[0])*(B[0]-C[0])*(C[0]-A[0]));
00992
00993 const double Brent_method_training_rate = -b/(2.0*c);
00994
00995 if(Brent_method_training_rate <= A[0] || Brent_method_training_rate >= C[0])
00996 {
00997 buffer << "OpenNN Exception: TrainingRateAlgorithm class.\n"
00998 << "double calculate_parabola_minimal_training_rate(Vector<double>&, Vector<double>&, Vector<double>&) const method.\n"
00999 << "Brent method training rate is not inside interval.\n"
01000 << "Interval: (" << A[0] << "," << B[0] << ")\n"
01001 << "Brent method training rate: " << Brent_method_training_rate << std::endl;
01002
01003 throw std::logic_error(buffer.str());
01004 }
01005
01006 return(Brent_method_training_rate);
01007 }
01008
01009
01010
01011
01014
01015 TiXmlElement* TrainingRateAlgorithm::to_XML(void) const
01016 {
01017 std::ostringstream buffer;
01018
01019
01020
01021 TiXmlElement* training_rate_algorithm_element = new TiXmlElement("TrainingRateAlgorithm");
01022 training_rate_algorithm_element->SetAttribute("Version", 4);
01023
01024
01025
01026 TiXmlElement* training_rate_method_element = new TiXmlElement("TrainingRateMethod");
01027 training_rate_algorithm_element->LinkEndChild(training_rate_method_element);
01028
01029 TiXmlText* training_rate_method_text = new TiXmlText(write_training_rate_method().c_str());
01030 training_rate_method_element->LinkEndChild(training_rate_method_text);
01031
01032
01033
01034 TiXmlElement* bracketing_factor_element = new TiXmlElement("BracketingFactor");
01035 training_rate_algorithm_element->LinkEndChild(bracketing_factor_element);
01036
01037 buffer.str("");
01038 buffer << bracketing_factor;
01039
01040 TiXmlText* bracketing_factor_text = new TiXmlText(buffer.str().c_str());
01041 bracketing_factor_element->LinkEndChild(bracketing_factor_text);
01042
01043
01044
01045 TiXmlElement* first_training_rate_element = new TiXmlElement("FirstTrainingRate");
01046 training_rate_algorithm_element->LinkEndChild(first_training_rate_element);
01047
01048 buffer.str("");
01049 buffer << first_training_rate;
01050
01051 TiXmlText* first_training_rate_text = new TiXmlText(buffer.str().c_str());
01052 first_training_rate_element->LinkEndChild(first_training_rate_text);
01053
01054
01055
01056 TiXmlElement* training_rate_tolerance_element = new TiXmlElement("TrainingRateTolerance");
01057 training_rate_algorithm_element->LinkEndChild(training_rate_tolerance_element);
01058
01059 buffer.str("");
01060 buffer << training_rate_tolerance;
01061
01062 TiXmlText* training_rate_tolerance_text = new TiXmlText(buffer.str().c_str());
01063 training_rate_tolerance_element->LinkEndChild(training_rate_tolerance_text);
01064
01065
01066
01067 TiXmlElement* warning_training_rate_element = new TiXmlElement("WarningTrainingRate");
01068 training_rate_algorithm_element->LinkEndChild(warning_training_rate_element);
01069
01070 buffer.str("");
01071 buffer << warning_training_rate;
01072
01073 TiXmlText* warning_training_rate_text = new TiXmlText(buffer.str().c_str());
01074 warning_training_rate_element->LinkEndChild(warning_training_rate_text);
01075
01076
01077
01078 TiXmlElement* error_training_rate_element = new TiXmlElement("ErrorTrainingRate");
01079 training_rate_algorithm_element->LinkEndChild(error_training_rate_element);
01080
01081 buffer.str("");
01082 buffer << error_training_rate;
01083
01084 TiXmlText* error_training_rate_text = new TiXmlText(buffer.str().c_str());
01085 error_training_rate_element->LinkEndChild(error_training_rate_text);
01086
01087
01088
01089 TiXmlElement* display_element = new TiXmlElement("Display");
01090 training_rate_algorithm_element->LinkEndChild(display_element);
01091
01092 buffer.str("");
01093 buffer << display;
01094
01095 TiXmlText* display_text = new TiXmlText(buffer.str().c_str());
01096 display_element->LinkEndChild(display_text);
01097
01098 return(training_rate_algorithm_element);
01099 }
01100
01101
01102
01103
01107
01108 void TrainingRateAlgorithm::from_XML(TiXmlElement* training_rate_algorithm_element)
01109 {
01110
01111
01112 TiXmlElement* training_rate_method_element = training_rate_algorithm_element->FirstChildElement("TrainingRateMethod");
01113
01114 if(training_rate_method_element)
01115 {
01116 std::string new_training_rate_method = training_rate_method_element->GetText();
01117
01118 try
01119 {
01120 set_training_rate_method(new_training_rate_method);
01121 }
01122 catch(std::exception& e)
01123 {
01124 std::cout << e.what() << std::endl;
01125 }
01126 }
01127
01128
01129
01130 TiXmlElement* bracketing_factor_element = training_rate_algorithm_element->FirstChildElement("BracketingFactor");
01131
01132 if(bracketing_factor_element)
01133 {
01134 double new_bracketing_factor = atof(bracketing_factor_element->GetText());
01135
01136 try
01137 {
01138 set_bracketing_factor(new_bracketing_factor);
01139 }
01140 catch(std::exception& e)
01141 {
01142 std::cout << e.what() << std::endl;
01143 }
01144 }
01145
01146
01147
01148 TiXmlElement* first_training_rate_element = training_rate_algorithm_element->FirstChildElement("FirstTrainingRate");
01149
01150 if(first_training_rate_element)
01151 {
01152 double new_first_training_rate = atof(first_training_rate_element->GetText());
01153
01154 try
01155 {
01156 set_first_training_rate(new_first_training_rate);
01157 }
01158 catch(std::exception& e)
01159 {
01160 std::cout << e.what() << std::endl;
01161 }
01162 }
01163
01164
01165
01166 TiXmlElement* training_rate_tolerance_element = training_rate_algorithm_element->FirstChildElement("TrainingRateTolerance");
01167
01168 if(training_rate_tolerance_element)
01169 {
01170 double new_training_rate_tolerance = atof(training_rate_tolerance_element->GetText());
01171
01172 try
01173 {
01174 set_training_rate_tolerance(new_training_rate_tolerance);
01175 }
01176 catch(std::exception& e)
01177 {
01178 std::cout << e.what() << std::endl;
01179 }
01180 }
01181
01182
01183
01184 TiXmlElement* warning_training_rate_element = training_rate_algorithm_element->FirstChildElement("WarningTrainingRate");
01185
01186 if(warning_training_rate_element)
01187 {
01188 double new_warning_training_rate = atof(warning_training_rate_element->GetText());
01189
01190 try
01191 {
01192 set_warning_training_rate(new_warning_training_rate);
01193 }
01194 catch(std::exception& e)
01195 {
01196 std::cout << e.what() << std::endl;
01197 }
01198 }
01199
01200
01201
01202 TiXmlElement* error_training_rate_element = training_rate_algorithm_element->FirstChildElement("ErrorTrainingRate");
01203
01204 if(error_training_rate_element)
01205 {
01206 double new_error_training_rate = atof(error_training_rate_element->GetText());
01207
01208 try
01209 {
01210 set_error_training_rate(new_error_training_rate);
01211 }
01212 catch(std::exception& e)
01213 {
01214 std::cout << e.what() << std::endl;
01215 }
01216 }
01217
01218
01219
01220 TiXmlElement* display_element = training_rate_algorithm_element->FirstChildElement("Display");
01221
01222 if(display_element)
01223 {
01224 std::string new_display = display_element->GetText();
01225
01226 try
01227 {
01228 set_display(new_display != "0");
01229 }
01230 catch(std::exception& e)
01231 {
01232 std::cout << e.what() << std::endl;
01233 }
01234 }
01235 }
01236
01237 }
01238
01239
01240
01241
01242
01243
01244
01245
01246
01247
01248
01249
01250
01251
01252
01253
01254
01255