SemSegLocal.cpp 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. /**
  2. * @file SemSegLocal.cpp
  3. * @brief semantic segmentation using image patches only
  4. * @author Erik Rodner
  5. * @date 05/08/2008
  6. */
  7. #include <iostream>
  8. #include "SemSegLocal.h"
  9. #include "vislearning/cbaselib/CachedExample.h"
  10. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
  11. #include "vislearning/features/fpfeatures/PixelPairFeature.h"
  12. #include "SemSegTools.h"
  13. using namespace OBJREC;
  14. using namespace std;
  15. using namespace NICE;
  16. SemSegLocal::SemSegLocal( const Config *conf,
  17. const MultiDataset *md )
  18. : SemanticSegmentation ( conf, &(md->getClassNames("train")) )
  19. {
  20. save_cache = conf->gB("FPCPixel", "save_cache", true );
  21. read_cache = conf->gB("FPCPixel", "read_cache", false );
  22. cache = conf->gS("FPCPixel", "cache", "fpc.data" );
  23. fpc = new FPCRandomForests ( conf, "FPCPixel" );
  24. fpc->setMaxClassNo ( classNames->getMaxClassno() );
  25. if ( read_cache ) {
  26. fprintf (stderr, "LocSSimpleFP:: Reading classifier data from %s\n", cache.c_str() );
  27. fpc->read ( cache );
  28. fprintf (stderr, "LocSSimpleFP:: successfully read\n" );
  29. } else {
  30. train ( conf, md );
  31. }
  32. }
  33. void SemSegLocal::train ( const Config *conf, const MultiDataset *md )
  34. {
  35. Examples examples;
  36. vector<CachedExample *> imgexamples;
  37. SemSegTools::collectTrainingExamples (
  38. conf,
  39. "FPCPixel", // config section for grid settings
  40. *((*md)["train"]),
  41. *classNames,
  42. examples,
  43. imgexamples );
  44. assert ( examples.size() > 0 );
  45. FeaturePool fp;
  46. PixelPairFeature hf (conf);
  47. hf.explode ( fp );
  48. fpc->train ( fp, examples );
  49. // clean up memory !!
  50. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  51. i != imgexamples.end();
  52. i++ )
  53. delete ( *i );
  54. if ( save_cache ) {
  55. fpc->save ( cache );
  56. }
  57. fp.destroy();
  58. }
  59. SemSegLocal::~SemSegLocal()
  60. {
  61. if ( fpc != NULL )
  62. delete fpc;
  63. }
  64. void SemSegLocal::semanticseg ( CachedExample *ce,
  65. NICE::Image & segresult,
  66. NICE::MultiChannelImageT<double> & probabilities )
  67. {
  68. // for speed optimization
  69. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  70. int xsize, ysize;
  71. ce->getImageSize ( xsize, ysize );
  72. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno()+1, true/*allocMem*/ );
  73. segresult.resize(xsize, ysize);
  74. Example pce ( ce, 0, 0 );
  75. long int offset = 0;
  76. for ( int y = 0 ; y < ysize ; y++ )
  77. for ( int x = 0 ; x < xsize ; x++,offset++ )
  78. {
  79. pce.x = x ; pce.y = y;
  80. ClassificationResult r = fpcrf->classify ( pce );
  81. segresult.setPixel(x,y,r.classno);
  82. for ( int i = 0 ; i < (int)probabilities.numChannels; i++ )
  83. probabilities.data[i][offset] = r.scores[i];
  84. }
  85. }