laplaceTests.cpp 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. /**
  2. * @file laplaceTests.cpp
  3. * @brief Laplace Approximation Tests
  4. * @author Erik Rodner
  5. * @date 02/17/2010
  6. */
  7. #include "core/imagedisplay/ImageDisplay.h"
  8. #include "core/basics/Config.h"
  9. #include "vislearning/baselib/ICETools.h"
  10. #include "vislearning/cbaselib/LabeledSet.h"
  11. #include "vislearning/classifier/kernelclassifier/LikelihoodFunction.h"
  12. #include "vislearning/classifier/kernelclassifier/LaplaceApproximation.h"
  13. #include "vislearning/classifier/kernelclassifier/KCGPLaplace.h"
  14. #include "vislearning/math/kernels/KernelData.h"
  15. #include "vislearning/math/kernels/Kernel.h"
  16. #include "vislearning/math/kernels/KernelRBF.h"
  17. #include "vislearning/math/kernels/KernelExp.h"
  18. using namespace std;
  19. using namespace OBJREC;
  20. using namespace NICE;
  21. /**
  22. Laplace Approximation Tests
  23. */
  24. int main (int argc, char **argv)
  25. {
  26. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  27. Config conf ( argc, argv );
  28. LabeledSetVector train;
  29. train.read ( conf.gS("main", "set" ), LabeledSetVector::FILEFORMAT_NOINDEX );
  30. LOOP_ALL_NONCONST ( train )
  31. {
  32. EACH_NONCONST (classno, v );
  33. v[0] /= 300.0;
  34. v[1] /= 300.0;
  35. }
  36. double rbf_sigma = conf.gD("main", "rbf_sigma", -2.0 );
  37. KernelRBF kernelFunction ( rbf_sigma, 0.0 );
  38. //KernelExp kernelFunction ( rbf_sigma, 0.0, 0.0 );
  39. KernelClassifier *classifier = new KCGPLaplace ( &conf, &kernelFunction );
  40. classifier->teach( train );
  41. FloatImage predictions ( 100, 100 );
  42. for ( uint i = 0 ; i < (uint)predictions.height(); i++ )
  43. for ( uint j = 0 ; j < (uint)predictions.width(); j++ )
  44. {
  45. double yy = i/(double)predictions.height();
  46. double xx = j/(double)predictions.width();
  47. Vector vec (2);
  48. vec[0] = xx;
  49. vec[1] = yy;
  50. if ( train.dimension() == 3 )
  51. vec.append(1.0);
  52. ClassificationResult r = classifier->classify ( vec );
  53. predictions.setPixel(j,i,r.scores[1]);
  54. }
  55. ColorImage img;
  56. ICETools::convertToRGB ( predictions, img );
  57. LOOP_ALL(train)
  58. {
  59. EACH(classno,vec);
  60. int xx = vec[0]*predictions.width();
  61. int yy = vec[1]*predictions.height();
  62. img.setPixel(xx,yy,0, 255*classno );
  63. img.setPixel(xx,yy,1, 255*classno );
  64. img.setPixel(xx,yy,2, 255*classno );
  65. }
  66. showImage ( img );
  67. return 0;
  68. }