DTBOblique.h
/**
 * @file DTBOblique.h
 * @brief oblique decision tree
 * @author Sven Sickert
 * @date 10/15/2014
 */
#ifndef DTBOBLIQUEINCLUDE
#define DTBOBLIQUEINCLUDE

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/basics/Config.h"

#include "DecisionTreeBuilder.h"
#include "vislearning/cbaselib/CachedExample.h"

namespace OBJREC {

/** random oblique decision tree */
class DTBOblique : public DecisionTreeBuilder
{
  protected:

    /////////////////////////
    /////////////////////////
    // PROTECTED VARIABLES //
    /////////////////////////
    /////////////////////////

    /** Whether to use Shannon entropy or not */
    bool useShannonEntropy;

    /** Whether to save indices in leaves or not */
    bool saveIndices;

    /** Whether to use one-vs-one or one-vs-all for multi-class scenarios */
    bool useOneVsOne;

    /** Number of steps for the complete search for the best threshold */
    int splitSteps;

    /** Maximum allowed depth of a tree */
    int maxDepth;

    /** Minimum number of examples in a leaf node */
    int minExamples;

    /** Regularization type */
    int regularizationType;

    /** Minimum entropy to continue with splitting */
    double minimumEntropy;

    /** Minimum information gain to continue with splitting */
    double minimumInformationGain;

    /** Initial regularization parameter */
    double lambdaInit;

    /////////////////////////
    /////////////////////////
    // PROTECTED METHODS //
    /////////////////////////
    /////////////////////////

    /**
     * @brief get data matrix X and label vector y
     * @param fp feature pool
     * @param examples all training examples
     * @param examples_selection indices of the selected example subset
     * @param matX data matrix (amountExamples x amountParameters)
     * @param vecY label vector (amountExamples)
     */
    void getDataAndLabel (
        const FeaturePool & fp,
        const Examples & examples,
        const std::vector<int> & examples_selection,
        NICE::Matrix & matX,
        NICE::Vector & vecY );

    /**
     * @brief return a regularization matrix of size (dimParams)x(dimParams)
     * @param X data matrix
     * @param XTXreg return regularized X'*X
     * @param regOption which kind of regularization
     * @param lambda regularization parameter (weighting)
     */
    void regularizeDataMatrix (
        const NICE::Matrix & X,
        NICE::Matrix & XTXreg,
        const int regOption,
        const double lambda );
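    // A minimal sketch (an assumption, not spelled out in this header) of what
    // the regularized matrix is good for: with the identity (Tikhonov) option,
    // regularizeDataMatrix would return
    //
    //   XTXreg = X' * X + lambda * I,
    //
    // and the weights of an oblique split could then be obtained as the ridge
    // solution beta = XTXreg^{-1} * X' * y, using matX and vecY from
    // getDataAndLabel.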
    /**
     * @brief recursive building method
     * @param fp feature pool
     * @param examples all training examples
     * @param examples_selection indices of the selected example subset
     * @param distribution class distribution in current node
     * @param entropy current entropy
     * @param maxClassNo maximum class number
     * @param depth current depth
     * @param curLambda current regularization parameter
     * @return Pointer to root/parent node
     */
    DecisionNode *buildRecursive (
        const FeaturePool & fp,
        const Examples & examples,
        std::vector<int> & examples_selection,
        FullVector & distribution,
        double entropy,
        int maxClassNo,
        int depth,
        double curLambda );
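    // Stopping criteria sketch (inferred from the members above, not stated
    // explicitly here): the recursion presumably creates a leaf node once
    // depth reaches maxDepth, the node entropy drops below minimumEntropy,
    // the best split's gain falls below minimumInformationGain, or fewer than
    // minExamples examples remain in the node.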
    /**
     * @brief compute entropy for left and right child
     * @param values feature values
     * @param threshold threshold for split
     * @param stat_left statistics for left child
     * @param stat_right statistics for right child
     * @param entropy_left entropy for left child
     * @param entropy_right entropy for right child
     * @param count_left number of examples in the left child
     * @param count_right number of examples in the right child
     * @param maxClassNo maximum class number
     * @return whether another split is possible or not
     */
    bool entropyLeftRight ( const FeatureValuesUnsorted & values,
                            double threshold,
                            double* stat_left,
                            double* stat_right,
                            double & entropy_left,
                            double & entropy_right,
                            double & count_left,
                            double & count_right,
                            int maxClassNo );
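    // A minimal sketch (an assumption) of the computation: if stat_left[c]
    // holds the (weighted) count of class c among examples whose feature value
    // falls below the threshold, then
    //
    //   entropy_left = - sum_c ( stat_left[c] / count_left )
    //                          * log ( stat_left[c] / count_left ),
    //
    // and analogously for the right child; a split would be rejected (return
    // value false) if either child ends up empty.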
  public:

    /** simple constructor */
    DTBOblique ( const NICE::Config *conf,
                 std::string section = "DTBOblique" );

    /** simple destructor */
    virtual ~DTBOblique();

    /**
     * @brief initial building method
     * @param fp feature pool
     * @param examples all training examples
     * @param maxClassNo maximum class number
     * @return Pointer to root/parent node
     */
    DecisionNode *build ( const FeaturePool & fp,
                          const Examples & examples,
                          int maxClassNo );

};

} //namespace

#endif
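/*
 * A minimal usage sketch, assuming a prepared FeaturePool and Examples;
 * fp, examples, maxClassNo and "oblique.conf" are hypothetical names:
 *
 *   NICE::Config conf ( "oblique.conf" );
 *   OBJREC::DTBOblique dtb ( &conf, "DTBOblique" );
 *   OBJREC::DecisionNode *root = dtb.build ( fp, examples, maxClassNo );
 */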