toyExampleUnsupervisedGP.cpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. /**
  2. * @file toyExampleUnsupervisedGP.cpp
  3. * @brief Toy example: Gaussian process regression on interactively selected 2D points
  4. * @author Erik Rodner
  5. * @date 04/07/2009
  6. */
  7. #include <iomanip>
  8. #include <core/imagedisplay/SimpleSelector.h>
  9. #include <core/image/CrossT.h>
  10. #include "core/basics/Config.h"
  11. #include "vislearning/baselib/ICETools.h"
  12. #include "vislearning/regression/gpregression/RegGaussianProcess.h"
  13. #include "vislearning/math/kernels/KernelRBF.h"
  14. #include "core/vector/VectorT.h"
  15. #include "core/vector/MatrixT.h"
  16. #include "core/image/ImageT.h"
  17. #include "core/imagedisplay/ImageDisplay.h"
  18. using namespace OBJREC;
  19. using namespace NICE;
  20. using namespace std;
  21. #ifndef NOVISUAL
  22. void selectTrainingSet ( VVector & train, Vector & labels, NICE::Image & img)
  23. {
  24. vector<int> colors;
  25. vector<CoordT<double> > points;
  26. NICE::selectColoredPoints ( img, points, colors, "Select some points!", 3 );
  27. int k = 0;
  28. for ( vector<CoordT<double> >::const_iterator i = points.begin();
  29. i != points.end(); i++,k++ )
  30. {
  31. NICE::Vector feature ( 2 );
  32. feature[0] = i->x / img.width();
  33. feature[1] = i->y / img.height();
  34. train.push_back ( feature );
  35. if ( colors[k] == 1 )
  36. colors[k] = -1;
  37. else if ( colors[k] == 2 )
  38. colors[k] = 1;
  39. else if ( colors[k] == 3 )
  40. colors[k] = 0;
  41. labels.append ( colors[k] );
  42. }
  43. }
  44. #endif
  45. void markBoundary ( const NICE::Image & imgclassno, NICE::Image & mark )
  46. {
  47. for ( int y = 0 ; y < imgclassno.height(); y++ )
  48. for ( int x = 0 ; x < imgclassno.width(); x++ )
  49. {
  50. int val = imgclassno.getPixel(x,y);
  51. bool boundary = false;
  52. for ( int i = -1 ; (i <= 1) && (!boundary) ; i++ )
  53. for ( int j = -1 ; (j <= 1) && (!boundary) ; j++ )
  54. {
  55. int xn = x + i;
  56. int yn = y + j;
  57. if ( (xn<0) || (yn<0) || (xn>=imgclassno.width()) || (yn>=imgclassno.height()) )
  58. continue;
  59. int valn = imgclassno.getPixel(xn,yn);
  60. if ( valn != val )
  61. boundary = true;
  62. }
  63. if ( boundary )
  64. mark.setPixel(x,y,1);
  65. }
  66. }
  67. /**
  68. just a toy tool
  69. */
  70. int main (int argc, char **argv)
  71. {
  72. #ifndef __clang__
  73. #ifndef __llvm__
  74. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  75. #endif
  76. #endif
  77. Config conf ( argc, argv );
  78. conf.store(cout);
  79. int xsize = conf.gI("main", "xsize", 300 );
  80. int ysize = conf.gI("main", "ysize", 300 );
  81. NICE::Image img (xsize, ysize);
  82. img.set(255);
  83. NICE::Image mark (img);
  84. mark.set(0);
  85. VVector train;
  86. Vector labels;
  87. std::string trainsetcache = conf.gS("main", "trainset", "");
  88. bool readtrainset = conf.gB("main", "readtrainset", false);
  89. if ( readtrainset && (trainsetcache.size() > 0 ) )
  90. {
  91. ifstream ifs ( trainsetcache.c_str(), ios::in );
  92. if ( !ifs.good() )
  93. fthrow(IOException, "Unable to read training data from " << trainsetcache << "." );
  94. ifs >> labels;
  95. train.restore ( ifs, VVector::FILEFORMAT_LINE );
  96. ifs.close ();
  97. cerr << "Labels: " << labels.size() << " // Examples " << train.size() << endl;
  98. }
  99. int k = 0;
  100. for ( VVector::const_iterator i = train.begin();
  101. i != train.end(); i++,k++ )
  102. {
  103. double classno = labels[k];
  104. const Vector & x = *i;
  105. Cross cross ( Coord( (int)(x[0]*mark.width()), (int)(x[1]*mark.height()) ), 10 );
  106. if ( classno < 0 )
  107. mark.draw ( cross, 1 );
  108. else if ( classno > 0 )
  109. mark.draw ( cross, 2 );
  110. else
  111. mark.draw ( cross, 3 );
  112. }
  113. bool selectManually = conf.gB("main", "select", true);
  114. if ( selectManually )
  115. {
  116. #ifdef NOVISUAL
  117. fprintf (stderr, "toyExample: visual manual selection needs ICE visualization\n");
  118. #else
  119. selectTrainingSet ( train, labels, img );
  120. #endif
  121. }
  122. bool writetrainset = conf.gB("main", "writetrainset", false);
  123. if ( writetrainset && (trainsetcache.size() > 0) )
  124. {
  125. ofstream ofs ( trainsetcache.c_str(), ios::out );
  126. if ( !ofs.good() )
  127. fthrow(IOException, "Unable to write training data to " << trainsetcache << "." );
  128. ofs << labels;
  129. ofs << endl;
  130. train.store ( ofs, VVector::FILEFORMAT_LINE );
  131. ofs.close ();
  132. }
  133. if ( train.size() <= 0 )
  134. {
  135. fthrow(Exception, "Size of the training set is zero!");
  136. }
  137. // do something
  138. KernelRBF kernelFunction ( conf.gD("main", "loggamma", 0.0) );
  139. RegressionAlgorithm *regression = new RegGaussianProcess ( &conf, &kernelFunction );
  140. cerr << labels << endl;
  141. regression->teach ( train, labels );
  142. NICE::FloatImage imgd (img.width(), img.height());
  143. NICE::Image imgclassno (img);
  144. for ( int y = 0 ; y < img.height(); y++ )
  145. for ( int x = 0 ; x < img.width(); x++ )
  146. {
  147. NICE::Vector example ( 2 );
  148. example[0] = x / (double)img.width();
  149. example[1] = y / (double)img.height();
  150. double value = regression->predict ( example );
  151. imgd.setPixel(x,y, 1.0 / ( 1.0 + exp(-value) ));
  152. imgclassno.setPixel(x,y, ( value < 0 ) ? 1 : 2);
  153. }
  154. markBoundary ( imgclassno, mark );
  155. Image imgScore ( img.width(), img.height() );
  156. floatToGrayScaled ( imgd, &imgScore );
  157. #ifndef NOVISUAL
  158. showImageOverlay ( img, mark );
  159. showImageOverlay ( imgScore, mark );
  160. showImageOverlay ( imgclassno, imgclassno );
  161. #endif
  162. return 0;
  163. }