// laplaceTests.cpp (2.2 KB)
/**
* @file laplaceTests.cpp
* @brief Laplace Approximation Tests
* @author Erik Rodner
* @date 02/17/2010
*/
  7. #include "core/imagedisplay/ImageDisplay.h"
  8. #include "core/basics/Config.h"
  9. #include "vislearning/baselib/ICETools.h"
  10. #include "vislearning/cbaselib/LabeledSet.h"
  11. #include "vislearning/classifier/kernelclassifier/LikelihoodFunction.h"
  12. #include "vislearning/classifier/kernelclassifier/LaplaceApproximation.h"
  13. #include "vislearning/classifier/kernelclassifier/KCGPLaplace.h"
  14. #include "vislearning/math/kernels/KernelData.h"
  15. #include "vislearning/math/kernels/Kernel.h"
  16. #include "vislearning/math/kernels/KernelRBF.h"
  17. #include "vislearning/math/kernels/KernelExp.h"
  18. using namespace std;
  19. using namespace OBJREC;
  20. using namespace NICE;
  21. /**
  22. Laplace Approximation Tests
  23. */
  24. int main (int argc, char **argv)
  25. {
  26. #ifdef __GLIBCXX__
  27. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  28. #endif
  29. Config conf ( argc, argv );
  30. LabeledSetVector train;
  31. train.read ( conf.gS("main", "set" ), LabeledSetVector::FILEFORMAT_NOINDEX );
  32. LOOP_ALL_NONCONST ( train )
  33. {
  34. EACH_NONCONST (classno, v );
  35. v[0] /= 300.0;
  36. v[1] /= 300.0;
  37. }
  38. double rbf_sigma = conf.gD("main", "rbf_sigma", -2.0 );
  39. KernelRBF kernelFunction ( rbf_sigma, 0.0 );
  40. //KernelExp kernelFunction ( rbf_sigma, 0.0, 0.0 );
  41. KernelClassifier *classifier = new KCGPLaplace ( &conf, &kernelFunction );
  42. classifier->teach( train );
  43. FloatImage predictions ( 100, 100 );
  44. for ( uint i = 0 ; i < (uint)predictions.height(); i++ )
  45. for ( uint j = 0 ; j < (uint)predictions.width(); j++ )
  46. {
  47. double yy = i/(double)predictions.height();
  48. double xx = j/(double)predictions.width();
  49. Vector vec (2);
  50. vec[0] = xx;
  51. vec[1] = yy;
  52. if ( train.dimension() == 3 )
  53. vec.append(1.0);
  54. ClassificationResult r = classifier->classify ( vec );
  55. predictions.setPixel(j,i,r.scores[1]);
  56. }
  57. ColorImage img;
  58. ICETools::convertToRGB ( predictions, img );
  59. LOOP_ALL(train)
  60. {
  61. EACH(classno,vec);
  62. int xx = vec[0]*predictions.width();
  63. int yy = vec[1]*predictions.height();
  64. img.setPixel(xx,yy,0, 255*classno );
  65. img.setPixel(xx,yy,1, 255*classno );
  66. img.setPixel(xx,yy,2, 255*classno );
  67. }
  68. showImage ( img );
  69. return 0;
  70. }