SemSegContextTree.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. /**
  2. * @file SemSegContextTree.h
  3. * @brief Context Trees -> Combination of decision tree and context information
  4. * @author Björn Fröhlich
  5. * @date 29.11.2011
  6. */
  7. #ifndef SemSegContextTreeINCLUDE
  8. #define SemSegContextTreeINCLUDE
  9. #include "SemanticSegmentation.h"
  10. #include <objrec/math/mathbase/VVector.h>
  11. #include "objrec/features/localfeatures/LFColorWeijer.h"
  12. namespace OBJREC {
  13. /** Localization system */
  14. class SemSegContextTree : public SemanticSegmentation
  15. {
  16. protected:
  17. class Node
  18. {
  19. public:
  20. /** probabilities for each class */
  21. vector<double> probs;
  22. /** left child node */
  23. int left;
  24. /** right child node */
  25. int right;
  26. /** position of feat for decision */
  27. int feat;
  28. /** decision stamp */
  29. double decision;
  30. /** is the node a leaf or not */
  31. bool isleaf;
  32. /** distribution in current node */
  33. vector<double> dist;
  34. /** depth of the node in the tree */
  35. int depth;
  36. /** simple constructor */
  37. Node():left(-1),right(-1),feat(-1), decision(-1.0), isleaf(false){}
  38. /** standard constructor */
  39. Node(int _left, int _right, int _feat, double _decision):left(_left),right(_right),feat(_feat), decision(_decision),isleaf(false){}
  40. };
  41. /** store features */
  42. VVector currentfeats;
  43. /** store the positions of the features */
  44. VVector positions;
  45. /** tree -> saved as vector of nodes */
  46. vector<Node> tree;
  47. /** local features */
  48. LFColorWeijer *lfcw;
  49. /** distance between features */
  50. int grid;
  51. /** maximum samples for tree */
  52. int maxSamples;
  53. /** count samples per label */
  54. map<int,int> labelcounter;
  55. /** map of labels */
  56. map<int,int> labelmap;
  57. /** map of labels inverse*/
  58. map<int,int> labelmapback;
  59. /** scalefactor for balancing for each class */
  60. vector<double> a;
  61. /** the minimum number of features allowed in a leaf */
  62. int minFeats;
  63. /** maximal depth of tree */
  64. int maxDepth;
  65. public:
  66. /** simple constructor */
  67. SemSegContextTree( const Config *conf, const MultiDataset *md );
  68. /** simple destructor */
  69. virtual ~SemSegContextTree();
  70. /**
  71. * test a single image
  72. * @param ce input data
  73. * @param segresult segmentation results
  74. * @param probabilities probabilities for each pixel
  75. */
  76. void semanticseg ( CachedExample *ce, NICE::Image & segresult, GenericImage<double> & probabilities );
  77. /**
  78. * the main training method
  79. * @param md training data
  80. */
  81. void train ( const MultiDataset *md );
  82. /**
  83. * compute best split for current settings
  84. * @param feats features
  85. * @param currentfeats matrix with current node for each feature
  86. * @param labels labels for each feature
  87. * @param node current node
  88. * @param splitfeat output feature position
  89. * @param splitval
  90. */
  91. void getBestSplit(const vector<vector<vector<vector<double> > > > &feats, vector<vector<vector<int> > > &currentfeats,const vector<vector<vector<int> > > &labels, int node, int &splitfeat, double &splitval);
  92. };
  93. } // namespace
  94. #endif