DTBOblique.h

/**
 * @file DTBOblique.h
 * @brief oblique decision tree
 * @author Sven Sickert
 * @date 10/15/2014
 */
#ifndef DTBOBLIQUEINCLUDE
#define DTBOBLIQUEINCLUDE

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/basics/Config.h"

#include "DecisionTreeBuilder.h"
#include "vislearning/cbaselib/CachedExample.h"

namespace OBJREC {

/** random oblique decision tree */
class DTBOblique : public DecisionTreeBuilder
{
  protected:

    /////////////////////////
    /////////////////////////
    // PROTECTED VARIABLES //
    /////////////////////////
    /////////////////////////
    /** Amount of randomly chosen thresholds */
    int random_split_tests;

    /** Maximum allowed depth of a tree */
    int max_depth;

    /** Minimum number of examples in a leaf node */
    int min_examples;

    /** Minimum entropy to continue with splitting */
    double minimum_entropy;

    /** Minimum information gain to continue with splitting */
    double minimum_information_gain;

    /** Whether to use Shannon entropy or not */
    bool use_shannon_entropy;

    /** Whether to save indices in leaves or not */
    bool save_indices;

    /** Regularization parameter */
    double lambdaInit;
    /////////////////////////
    /////////////////////////
    //  PROTECTED METHODS  //
    /////////////////////////
    /////////////////////////
    /**
     * @brief get data matrix X and label vector y
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of selected example subset
     * @param matX data matrix (amountExamples x amountParameters)
     * @param vecY label vector (amountExamples)
     */
    void getDataAndLabel (
        const FeaturePool & fp,
        const Examples & examples,
        const std::vector<int> & examples_selection,
        NICE::Matrix & matX,
        NICE::Vector & vecY );
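    // Note (an assumption, not stated in this header): matX and vecY are
    // typically used to fit the oblique split direction w of a node, e.g. by
    // regularized least squares, w = (X^T X + lambda*I)^{-1} X^T y, where the
    // regularization strength starts at lambdaInit and is handed down the
    // recursion as curLambda.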
    /**
     * @brief recursive building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of selected example subset
     * @param distribution class distribution in current node
     * @param entropy current entropy
     * @param maxClassNo maximum class number
     * @param depth current depth
     * @param curLambda regularization parameter for the current node
     * @return Pointer to root/parent node
     */
    DecisionNode *buildRecursive (
        const FeaturePool & fp,
        const Examples & examples,
        std::vector<int> & examples_selection,
        FullVector & distribution,
        double entropy,
        int maxClassNo,
        int depth,
        double curLambda );
    /**
     * @brief compute entropy for left and right child
     * @param values feature values
     * @param threshold threshold for split
     * @param stat_left statistics for left child
     * @param stat_right statistics for right child
     * @param entropy_left entropy for left child
     * @param entropy_right entropy for right child
     * @param count_left number of examples in left child
     * @param count_right number of examples in right child
     * @param maxClassNo maximum class number
     * @return whether another split is possible or not
     */
    bool entropyLeftRight ( const FeatureValuesUnsorted & values,
                            double threshold,
                            double* stat_left,
                            double* stat_right,
                            double & entropy_left,
                            double & entropy_right,
                            double & count_left,
                            double & count_right,
                            int maxClassNo );
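    // Note (illustrative, not part of the original interface description): with
    // per-class counts n_c gathered in stat_left/stat_right, the Shannon entropy
    // of a child holding N examples is H = -sum_c (n_c/N) * log(n_c/N); splitting
    // only continues while this value stays above minimum_entropy.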
  public:

    /** simple constructor */
    DTBOblique ( const NICE::Config *conf,
                 std::string section = "DTBOblique" );

    /** simple destructor */
    virtual ~DTBOblique ();

    /**
     * @brief initial building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param maxClassNo maximum class number
     * @return Pointer to root/parent node
     */
    DecisionNode *build ( const FeaturePool & fp,
                          const Examples & examples,
                          int maxClassNo );

};

} //namespace

#endif
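
For orientation, a minimal usage sketch follows. It relies only on the declarations above; how the FeaturePool and Examples are assembled, and which keys the NICE::Config section "DTBOblique" actually reads, are assumptions rather than facts from this header.

// Minimal usage sketch: train an oblique decision tree from an existing
// FeaturePool/Examples pair. The training data preparation is assumed;
// only the constructor and build() come from the header above.
#include "core/basics/Config.h"
#include "DTBOblique.h"

using namespace OBJREC;

DecisionNode *trainObliqueTree ( const FeaturePool & fp,
                                 const Examples & examples,
                                 int maxClassNo )
{
    NICE::Config conf;                // empty config: defaults of the "DTBOblique" section apply
    DTBOblique treeBuilder ( &conf ); // section name defaults to "DTBOblique"

    // build() starts the recursive construction and returns the root node of the trained tree
    return treeBuilder.build ( fp, examples, maxClassNo );
}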