toyExample.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. /**
  2. * @file toyExample.cpp
  3. * @brief just a toy tool
  4. * @author Erik Rodner
  5. * @date 04/07/2009
  6. */
  7. #include <iomanip>
  8. #include <core/imagedisplay/SimpleSelector.h>
  9. #include <core/imagedisplay/OverlayColors.h>
  10. #include <core/image/CrossT.h>
  11. #include <core/image/CircleT.h>
  12. #include "core/basics/Config.h"
  13. #include "vislearning/baselib/ICETools.h"
  14. #include "vislearning/classifier/genericClassifierSelection.h"
  15. #ifdef NOVISUAL
  16. #include <vislearning/nice_nonvis.h>
  17. #else
  18. #include <vislearning/nice.h>
  19. #endif
  20. using namespace OBJREC;
  21. using namespace NICE;
  22. using namespace std;
  23. #ifndef NOVISUAL
  24. void selectTrainingSet ( LabeledSetVector & train, NICE::Image & img, int numClasses,
  25. bool addBias = false )
  26. {
  27. vector<int> colors;
  28. vector<CoordT<double> > points;
  29. NICE::selectColoredPoints ( img, points, colors, "Select some points!", numClasses );
  30. int k = 0;
  31. for ( vector<CoordT<double> >::const_iterator i = points.begin();
  32. i != points.end(); i++,k++ )
  33. {
  34. NICE::Vector feature ( addBias ? 3 : 2 );
  35. feature[0] = i->x;
  36. feature[1] = i->y;
  37. if ( addBias )
  38. feature[2] = 1.0;
  39. train.add ( colors[k]-1, feature );
  40. }
  41. }
  42. #endif
  43. void markBoundary ( const NICE::Image & imgclassno, NICE::Image & mark )
  44. {
  45. for ( int y = 0 ; y < imgclassno.height(); y++ )
  46. for ( int x = 0 ; x < imgclassno.width(); x++ )
  47. {
  48. int val = imgclassno.getPixel(x,y);
  49. bool boundary = false;
  50. for ( int i = -1 ; (i <= 1) && (!boundary) ; i++ )
  51. for ( int j = -1 ; (j <= 1) && (!boundary) ; j++ )
  52. {
  53. int xn = x + i;
  54. int yn = y + j;
  55. if ( (xn<0) || (yn<0) || (xn>=imgclassno.width()) || (yn>=imgclassno.height()) )
  56. continue;
  57. int valn = imgclassno.getPixel(xn,yn);
  58. if ( valn != val )
  59. boundary = true;
  60. }
  61. if ( boundary )
  62. mark.setPixel(x,y,1);
  63. }
  64. }
  65. /**
  66. just a toy tool
  67. */
  68. int main (int argc, char **argv)
  69. {
  70. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  71. Config conf ( argc, argv );
  72. conf.store(cout);
  73. std::string classifier_type = conf.gS("main", "classifier", "sparse_logistic_regression");
  74. fprintf (stderr, "Classifier type: %s\n", classifier_type.c_str() );
  75. VecClassifier *vec_classifier = GenericClassifierSelection::selectVecClassifier ( &conf, classifier_type );
  76. if ( vec_classifier == NULL )
  77. {
  78. fprintf (stderr, "This classifier type is unknown !\n");
  79. exit(-1);
  80. }
  81. conf.sS("VCSVMLight", "normalization_type", "none" );
  82. int xsize = conf.gI("main", "xsize", 300 );
  83. int ysize = conf.gI("main", "ysize", 300 );
  84. int numClasses = conf.gI("main", "numClasses", 2 );
  85. vec_classifier->setMaxClassNo(numClasses);
  86. bool addBias = conf.gB("main", "addbias", "false" );
  87. NICE::Image img (xsize, ysize);
  88. NICE::Image mark (img);
  89. mark.set(0);
  90. img.set(255);
  91. LabeledSetVector train;
  92. std::string trainsetcache = conf.gS("main", "trainset", "");
  93. bool readtrainset = conf.gB("main", "readtrainset", false);
  94. bool selectManually = conf.gB("main", "select", true);
  95. if ( selectManually )
  96. {
  97. #ifdef NOVISUAL
  98. fprintf (stderr, "toyExample: visual manual selection needs ICE visualization\n");
  99. #else
  100. selectTrainingSet ( train, img, numClasses, addBias );
  101. #endif
  102. }
  103. if ( readtrainset && (trainsetcache.size() > 0 ) )
  104. {
  105. train.read ( trainsetcache, LabeledSetVector::FILEFORMAT_NOINDEX );
  106. }
  107. LOOP_ALL(train)
  108. {
  109. EACH(classno,x);
  110. if ( classno == 0 ) {
  111. Cross cross ( Coord( (int)(x[0]), (int)(x[1]) ), 10 );
  112. mark.draw ( cross, classno+2 );
  113. } else {
  114. Circle circle ( Coord( (int)(x[0]), (int)(x[1]) ), 10 );
  115. mark.draw ( circle, classno+2 );
  116. }
  117. }
  118. bool writetrainset = conf.gB("main", "writetrainset", false);
  119. if ( writetrainset && (trainsetcache.size() > 0) )
  120. train.save ( trainsetcache, LabeledSetVector::FILEFORMAT_NOINDEX);
  121. if ( train.count() <= 0 )
  122. {
  123. fprintf (stderr, "toyExample: size of the training set is zero!\n");
  124. exit(-1);
  125. }
  126. fprintf (stderr, "Dimension of the training set: %d\n", train.dimension() );
  127. vec_classifier->teach ( train );
  128. vec_classifier->finishTeaching();
  129. NICE::FloatImage imgd (img.width(), img.height());
  130. NICE::Image imgclassno (img);
  131. for ( int y = 0 ; y < img.height(); y++ )
  132. for ( int x = 0 ; x < img.width(); x++ )
  133. {
  134. NICE::Vector example ( addBias ? 3 : 2 );
  135. example[0] = x;
  136. example[1] = y;
  137. if ( addBias )
  138. example[2] = 1.0;
  139. ClassificationResult r = vec_classifier->classify(example);
  140. if ( numClasses == 2 )
  141. {
  142. imgd.setPixel(x,y,(r.scores.get(1)));
  143. } else {
  144. imgd.setPixel(x, y, r.classno / (double)(numClasses-1) );
  145. }
  146. imgclassno.setPixel(x,y,r.classno);
  147. }
  148. markBoundary ( imgclassno, mark );
  149. floatToGrayScaled ( imgd, &img );
  150. showImageOverlay ( img, mark );
  151. string resultimg = conf.gS("main", "resultimg", "");
  152. if ( resultimg.size() > 0 )
  153. {
  154. ColorImage result;
  155. grayToRGB( img, &result );
  156. for ( uint y = 0 ; y < result.height() ; y++ )
  157. for ( uint x = 0 ; x < result.width() ; x++ )
  158. if ( mark.getPixel(x,y) > 0 ) {
  159. result.setPixel(x,y,0,overlayColorTable[mark.getPixel(x,y)][0]);
  160. result.setPixel(x,y,1,overlayColorTable[mark.getPixel(x,y)][1]);
  161. result.setPixel(x,y,2,overlayColorTable[mark.getPixel(x,y)][2]);
  162. }
  163. ImageFile imgf ( resultimg );
  164. imgf.writer(&result);
  165. }
  166. showImageOverlay ( imgclassno, imgclassno );
  167. return 0;
  168. }