SemSegLocal.cpp 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. /**
  2. * @file SemSegLocal.cpp
  3. * @brief semantic segmentation using image patches only
  4. * @author Erik Rodner
  5. * @date 05/08/2008
  6. */
  7. #include <iostream>
  8. #include "SemSegLocal.h"
  9. #include "vislearning/cbaselib/CachedExample.h"
  10. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
  11. #include "vislearning/features/fpfeatures/PixelPairFeature.h"
  12. #include "SemSegTools.h"
  13. using namespace OBJREC;
  14. using namespace std;
  15. using namespace NICE;
  16. SemSegLocal::SemSegLocal ( const Config *conf,
  17. const MultiDataset *md )
  18. : SemanticSegmentation ( conf, & ( md->getClassNames ( "train" ) ) )
  19. {
  20. save_cache = conf->gB ( "FPCPixel", "save_cache", true );
  21. read_cache = conf->gB ( "FPCPixel", "read_cache", false );
  22. cache = conf->gS ( "FPCPixel", "cache", "fpc.data" );
  23. fpc = new FPCRandomForests ( conf, "FPCPixel" );
  24. fpc->setMaxClassNo ( classNames->getMaxClassno() );
  25. if ( read_cache ) {
  26. fprintf ( stderr, "LocSSimpleFP:: Reading classifier data from %s\n", cache.c_str() );
  27. fpc->read ( cache );
  28. fprintf ( stderr, "LocSSimpleFP:: successfully read\n" );
  29. } else {
  30. train ( conf, md );
  31. }
  32. }
  33. void SemSegLocal::train ( const Config *conf, const MultiDataset *md )
  34. {
  35. Examples examples;
  36. vector<CachedExample *> imgexamples;
  37. SemSegTools::collectTrainingExamples (
  38. conf,
  39. "FPCPixel", // config section for grid settings
  40. * ( ( *md ) ["train"] ),
  41. *classNames,
  42. examples,
  43. imgexamples );
  44. assert ( examples.size() > 0 );
  45. FeaturePool fp;
  46. PixelPairFeature hf ( conf );
  47. hf.explode ( fp );
  48. fpc->train ( fp, examples );
  49. // clean up memory !!
  50. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  51. i != imgexamples.end();
  52. i++ )
  53. delete ( *i );
  54. if ( save_cache ) {
  55. fpc->save ( cache );
  56. }
  57. fp.destroy();
  58. }
  59. SemSegLocal::~SemSegLocal()
  60. {
  61. if ( fpc != NULL )
  62. delete fpc;
  63. }
  64. void SemSegLocal::semanticseg ( CachedExample *ce,
  65. NICE::ImageT<int> & segresult,
  66. NICE::MultiChannelImageT<double> & probabilities )
  67. {
  68. // for speed optimization
  69. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  70. int xsize, ysize;
  71. ce->getImageSize ( xsize, ysize );
  72. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
  73. segresult.resize ( xsize, ysize );
  74. Example pce ( ce, 0, 0 );
  75. long int offset = 0;
  76. for ( int y = 0 ; y < ysize ; y++ )
  77. for ( int x = 0 ; x < xsize ; x++, offset++ )
  78. {
  79. pce.x = x ;
  80. pce.y = y;
  81. ClassificationResult r = fpcrf->classify ( pce );
  82. segresult.setPixel ( x, y, r.classno );
  83. for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
  84. probabilities[i](x,y) = r.scores[i];
  85. }
  86. }