TestGradientDescent.cpp
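// Unit tests for GradientDescentOptimizer, built only when CppUnit support
// (NICE_USELIB_CPPUNIT) is enabled.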

#ifdef NICE_USELIB_CPPUNIT

#include <string>
#include <exception>
#include <map>

#include "TestGradientDescent.h"
#include "MyCostFunction.h"

using namespace std;

const bool verboseStartEnd = true;
const bool verbose = true;
//const bool verbose = false;

CPPUNIT_TEST_SUITE_REGISTRATION( TestGradientDescent );

void TestGradientDescent::setUp() {
}

void TestGradientDescent::tearDown() {
}
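// Minimize a one-dimensional instance of MyCostFunction with plain gradient
// descent, starting from x = 2.0, and check that the optimizer converges to
// the expected minimum at 4.2.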
void TestGradientDescent::testGD_1Dim ()
{
  if (verboseStartEnd)
    std::cerr << "================== TestGradientDescent::testGD_1Dim ===================== " << std::endl;

  int dim (1);
  CostFunction *func = new MyCostFunction(dim, verbose);

  //initial guess: 2.0
  optimization::matrix_type initialParams (dim, 1);
  initialParams.Set(2.0);

  //we use a dimension scale of 1.0
  optimization::matrix_type scales (dim, 1);
  scales.Set(1.0);

  //set up the optimization problem
  SimpleOptProblem optProblem ( func, initialParams, scales );
  optProblem.setMaximize(false);

  GradientDescentOptimizer optimizer;

  //we search with a step width of 1.0
  optimization::matrix_type searchSteps (dim, 1);
  searchSteps[0][0] = 1.0;

  //optimizer.setVerbose(true);
  optimizer.setStepSize( searchSteps );
  optimizer.setMaxNumIter(true, 1000);
  optimizer.setFuncTol(true, 1e-8);

  optimizer.optimizeProb ( optProblem );

  optimization::matrix_type optimizedParams (optProblem.getAllCurrentParams());

  double goal(4.2);

  if (verbose)
    std::cerr << "1d optimization -- result " << optimizedParams[0][0] << " -- goal: " << goal << std::endl;

  CPPUNIT_ASSERT_DOUBLES_EQUAL( goal, optimizedParams[0][0], 1e-4 /* tolerance */);

  if (verboseStartEnd)
    std::cerr << "================== TestGradientDescent::testGD_1Dim done ===================== " << std::endl;
}
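// Same setup in two dimensions: start from (2.0, 2.0) and check that the
// optimizer converges to the expected minimum at (4.7, 1.1).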
void TestGradientDescent::testGD_2Dim()
{
  if (verboseStartEnd)
    std::cerr << "================== TestGradientDescent::testGD_2Dim ===================== " << std::endl;

  int dim (2);
  CostFunction *func = new MyCostFunction(dim, verbose);

  //initial guess: 2.0 in every dimension
  optimization::matrix_type initialParams (dim, 1);
  initialParams.Set(2.0);

  //we use a dimension scale of 1.0
  optimization::matrix_type scales (dim, 1);
  scales.Set(1.0);

  //set up the optimization problem
  SimpleOptProblem optProblem ( func, initialParams, scales );
  optProblem.setMaximize(false);

  GradientDescentOptimizer optimizer;

  //we search with a step width of 1.0 in every dimension
  optimization::matrix_type searchSteps (dim, 1);
  searchSteps[0][0] = 1.0;
  searchSteps[1][0] = 1.0;

  //optimizer.setVerbose(true);
  optimizer.setStepSize( searchSteps );
  optimizer.setMaxNumIter(true, 1000);
  optimizer.setFuncTol(true, 1e-8);

  optimizer.optimizeProb ( optProblem );

  optimization::matrix_type optimizedParams (optProblem.getAllCurrentParams());

  double goalFirstDim(4.7);
  double goalSecondDim(1.1);

  if (verbose)
  {
    std::cerr << "2d optimization 1st dim -- result " << optimizedParams[0][0] << " -- goal: " << goalFirstDim << std::endl;
    std::cerr << "2d optimization 2nd dim -- result " << optimizedParams[1][0] << " -- goal: " << goalSecondDim << std::endl;
  }

  CPPUNIT_ASSERT_DOUBLES_EQUAL( goalFirstDim, optimizedParams[0][0], 1e-4 /* tolerance */);
  CPPUNIT_ASSERT_DOUBLES_EQUAL( goalSecondDim, optimizedParams[1][0], 1e-4 /* tolerance */);

  if (verboseStartEnd)
    std::cerr << "================== TestGradientDescent::testGD_2Dim done ===================== " << std::endl;
}

#endif
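Note: MyCostFunction.h is not part of this listing, so the objective these tests minimize is not shown. The assertions only require a cost function whose minima lie at the asserted goal values (4.2 in 1D, (4.7, 1.1) in 2D). The following is a minimal, hypothetical sketch under that assumption, using separable quadratic bowls centered at those points; the function names myCostFunction1D/myCostFunction2D are illustrative and the actual MyCostFunction (a CostFunction subclass in the library) may differ.

// Hypothetical stand-ins for the objective the tests appear to minimize.
#include <cassert>
#include <vector>

// 1D case: quadratic with its minimum at x = 4.2 (matches the 1D goal).
double myCostFunction1D(double x)
{
    return (x - 4.2) * (x - 4.2);
}

// 2D case: separable quadratic bowl with its minimum at (4.7, 1.1)
// (matches the 2D goals).
double myCostFunction2D(const std::vector<double>& x)
{
    assert(x.size() == 2);
    return (x[0] - 4.7) * (x[0] - 4.7) + (x[1] - 1.1) * (x[1] - 1.1);
}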