DTBRandomOblique.h

/**
 * @file DTBRandomOblique.h
 * @brief random oblique decision tree
 * @author Sven Sickert
 * @date 10/15/2014
 */
#ifndef DTBRANDOMOBLIQUEINCLUDE
#define DTBRANDOMOBLIQUEINCLUDE

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/basics/Config.h"

#include "DecisionTreeBuilder.h"
#include "vislearning/cbaselib/CachedExample.h"

namespace OBJREC {

/** random oblique decision tree */
class DTBRandomOblique : public DecisionTreeBuilder
{

  protected:

    /////////////////////////
    /////////////////////////
    // PROTECTED VARIABLES //
    /////////////////////////
    /////////////////////////

    /** Amount of randomly chosen thresholds */
    int random_split_tests;

    /** Amount of randomly chosen features */
    int random_features;

    /** Maximum allowed depth of a tree */
    int max_depth;
    /** Minimum amount of examples in a leaf node */
    int min_examples;

    /** Minimum entropy to continue with splitting */
    double minimum_entropy;

    /** Minimum information gain to continue with splitting */
    double minimum_information_gain;

    /** Whether to use Shannon entropy or not */
    bool use_shannon_entropy;

    /** Whether to save indices in leaves or not */
    bool save_indices;
    /////////////////////////
    /////////////////////////
    //  PROTECTED METHODS  //
    /////////////////////////
    /////////////////////////
    /**
     * @brief recursive building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of selected example subset
     * @param distribution class distribution in current node
     * @param entropy current entropy
     * @param maxClassNo maximum class number
     * @param depth current depth
     * @return Pointer to root/parent node
     */
    DecisionNode *buildRecursive ( const FeaturePool & fp,
                                   const Examples & examples,
                                   std::vector<int> & examples_selection,
                                   FullVector & distribution,
                                   double entropy,
                                   int maxClassNo,
                                   int depth );
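    /* Note (a sketch based on the parameters above, not taken from the
     * implementation file): the recursion presumably turns the current node
     * into a leaf once depth reaches max_depth, the selection holds fewer
     * than min_examples examples, or the node entropy / information gain
     * falls below minimum_entropy / minimum_information_gain. */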
    /**
     * @brief compute entropy for left and right child
     * @param values feature values
     * @param threshold threshold for split
     * @param stat_left statistics for left child
     * @param stat_right statistics for right child
     * @param entropy_left entropy for left child
     * @param entropy_right entropy for right child
     * @param count_left number of feature values in the left child
     * @param count_right number of feature values in the right child
     * @param maxClassNo maximum class number
     * @return whether another split is possible or not
     */
    bool entropyLeftRight ( const FeatureValuesUnsorted & values,
                            double threshold,
                            double* stat_left,
                            double* stat_right,
                            double & entropy_left,
                            double & entropy_right,
                            double & count_left,
                            double & count_right,
                            int maxClassNo );
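    /* Sketch (an assumption, not taken from the implementation file):
     * with per-class counts n_c accumulated in stat_left / stat_right and
     * N = sum_c n_c, the per-child entropy would typically be the Shannon
     * entropy
     *   H = - sum_c (n_c / N) * log (n_c / N),
     * evaluated once for the left and once for the right statistics. */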
  public:

    /** simple constructor */
    DTBRandomOblique ( const NICE::Config *conf,
                       std::string section = "DTBRandomOblique" );

    /** simple destructor */
    virtual ~DTBRandomOblique();

    /**
     * @brief initial building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param maxClassNo maximum class number
     * @return Pointer to root/parent node
     */
    DecisionNode *build ( const FeaturePool & fp,
                          const Examples & examples,
                          int maxClassNo );
};

} //namespace

#endif
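
A minimal usage sketch (not part of the header; trainObliqueTree is a hypothetical helper, and the sketch assumes a NICE::Config with a "DTBRandomOblique" section plus an already populated FeaturePool and Examples, as expected by build()):

#include "DTBRandomOblique.h"   // include path may differ depending on the build setup

using namespace OBJREC;

// hypothetical helper, for illustration only
DecisionNode *trainObliqueTree ( const NICE::Config *conf,
                                 const FeaturePool & fp,
                                 const Examples & examples,
                                 int maxClassNo )
{
  // parameters such as max_depth or min_examples are read from the
  // "DTBRandomOblique" section of the configuration
  DTBRandomOblique builder ( conf, "DTBRandomOblique" );

  // grow the oblique decision tree from the full training set and
  // return a pointer to its root node
  return builder.build ( fp, examples, maxClassNo );
}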