/**
 * @file DTBOblique.h
 * @brief oblique decision tree
 * @author Sven Sickert
 * @date 10/15/2014
 */
#ifndef DTBOBLIQUEINCLUDE
#define DTBOBLIQUEINCLUDE

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/basics/Config.h"

#include "DecisionTreeBuilder.h"
#include "vislearning/cbaselib/CachedExample.h"

namespace OBJREC {

struct SplitInfo {
    double threshold;
    double informationGain;
    double entropyLeft;
    double entropyRight;
    double *distLeft;
    double *distRight;
    NICE::Vector params;
};
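
/* A minimal sketch of how a SplitInfo is typically set up (an assumption;
 * the header does not fix the size of the distribution arrays). distLeft
 * and distRight are caller-allocated per-class histograms:
 *
 *   SplitInfo best;
 *   best.distLeft  = new double [ maxClassNo + 1 ];
 *   best.distRight = new double [ maxClassNo + 1 ];
 *   best.informationGain = 0.0; // improved during the threshold search
 *   best.params = beta;         // weight vector of the oblique split
 */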
/** random oblique decision tree */
class DTBOblique : public DecisionTreeBuilder
{
  protected:

    /////////////////////////
    /////////////////////////
    // PROTECTED VARIABLES //
    /////////////////////////
    /////////////////////////

    /** Whether to use Shannon entropy or not */
    bool useShannonEntropy;

    /** Whether to save indices in leaves or not */
    bool saveIndices;

    /** Whether to use one-vs-one or one-vs-all for multi-class scenarios */
    bool useOneVsOne;

    /** Number of steps for the complete search for the best threshold */
    int splitSteps;

    /** Maximum allowed depth of a tree */
    int maxDepth;

    /** Minimum number of examples in a leaf node */
    int minExamples;

    /** Regularization type */
    int regularizationType;

    /** Minimum entropy to continue with splitting */
    double minimumEntropy;

    /** Minimum information gain to continue with splitting */
    double minimumInformationGain;

    /** Initial regularization parameter */
    double lambdaInit;

    /////////////////////////
    /////////////////////////
    //  PROTECTED METHODS  //
    /////////////////////////
    /////////////////////////

    /**
     * @brief get data matrix X and label vector y
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of the selected example subset
     * @param matX data matrix (amountExamples x amountParameters)
     * @param vecY label vector (amountExamples)
     */
    void getDataAndLabel (
        const NICE::FeaturePool &fp,
        const Examples &examples,
        const std::vector<int> &examples_selection,
        NICE::Matrix &matX,
        NICE::Vector &vecY );
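
    /* A minimal sketch of how the assembled X and y are typically used to
     * fit the oblique split direction beta via regularized least squares;
     * this is an assumption based on the declared interface, not the
     * definitive implementation:
     *
     *   NICE::Matrix X;  NICE::Vector y;
     *   getDataAndLabel ( fp, examples, selection, X, y );
     *   NICE::Matrix XTXr;
     *   regularizeDataMatrix ( X, XTXr, regularizationType, lambda );
     *   // solve ( X'X + lambda * R ) * beta = X' * y for beta
     */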

    /**
     * @brief compute the regularized matrix X'*X of size (dimParams)x(dimParams)
     * @param X data matrix
     * @param XTXreg returned regularized X'*X
     * @param regOption which kind of regularization
     * @param lambda regularization parameter (weighting)
     */
    void regularizeDataMatrix (
        const NICE::Matrix &X,
        NICE::Matrix &XTXreg,
        const int regOption,
        const double lambda );
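
    /* For illustration, the simplest variant is a ridge (Tikhonov) penalty;
     * assuming one of the regOption values selects it (an assumption; the
     * option values are defined in the implementation), the result is
     *
     *   // XTXreg = X'X + lambda * I
     *
     * i.e. X'X with lambda added on its diagonal entries.
     */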

    /**
     * @brief find best threshold for current splitting
     * @param values feature values
     * @param bestSplitInfo struct holding the best split information
     * @param beta weight vector of the current oblique split
     * @param e entropy before split
     * @param maxClassNo maximum class number
     */
    void findBestSplitThreshold (
        FeatureValuesUnsorted &values,
        SplitInfo &bestSplitInfo,
        const NICE::Vector &beta,
        const double &e,
        const int &maxClassNo );
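
    /* A minimal sketch of the complete threshold search, assuming the
     * splitSteps candidate thresholds are spaced uniformly between the
     * minimum and maximum projected feature value (an assumption based on
     * the declared members, not the definitive implementation):
     *
     *   for ( int i = 0; i < splitSteps; i++ )
     *   {
     *       double t = minVal + i * ( maxVal - minVal ) / (double)splitSteps;
     *       // evaluate entropyLeftRight(...) at t and keep the threshold
     *       // with the highest information gain in bestSplitInfo
     *   }
     */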

    /**
     * @brief recursive building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of the selected example subset
     * @param distribution class distribution in current node
     * @param entropy current entropy
     * @param maxClassNo maximum class number
     * @param depth current depth
     * @param curLambda current regularization parameter
     * @return Pointer to root/parent node
     */
    DecisionNode *buildRecursive (
        const FeaturePool &fp,
        const Examples &examples,
        std::vector<int> &examples_selection,
        FullVector &distribution,
        double entropy,
        int maxClassNo,
        int depth,
        double curLambda );
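
    /* The stopping criteria implied by the members above (a sketch, not
     * the definitive implementation): the recursion returns a leaf node
     * when
     *
     *   depth > maxDepth
     *   || (int)examples_selection.size() < minExamples
     *   || entropy < minimumEntropy
     *
     * and a split is only accepted if its information gain exceeds
     * minimumInformationGain.
     */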

    /**
     * @brief compute entropy for left and right child
     * @param values feature values
     * @param threshold threshold for split
     * @param stat_left statistics for left child
     * @param stat_right statistics for right child
     * @param entropy_left entropy for left child
     * @param entropy_right entropy for right child
     * @param count_left number of examples in left child
     * @param count_right number of examples in right child
     * @param maxClassNo maximum class number
     * @return whether another split is possible or not
     */
    bool entropyLeftRight ( const FeatureValuesUnsorted &values,
                            double threshold,
                            double *stat_left,
                            double *stat_right,
                            double &entropy_left,
                            double &entropy_right,
                            double &count_left,
                            double &count_right,
                            int maxClassNo );
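
    /* With useShannonEntropy enabled, the child entropies follow the usual
     * definition H = -sum_c p_c * log(p_c). A minimal sketch for the left
     * child, assuming stat_left holds per-class counts of size maxClassNo+1:
     *
     *   double H = 0.0;
     *   for ( int c = 0; c <= maxClassNo; c++ )
     *       if ( stat_left[c] > 0.0 )
     *       {
     *           double p = stat_left[c] / count_left;
     *           H -= p * log ( p );
     *       }
     */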

  public:

    /** simple constructor */
    DTBOblique ( const NICE::Config *conf,
                 std::string section = "DTBOblique" );

    /** simple destructor */
    virtual ~DTBOblique ();
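
    /* A minimal sketch of configuring the builder, with hypothetical key
     * names for illustration only (the actual keys are defined in the
     * implementation):
     *
     *   NICE::Config conf;
     *   conf.sI ( "DTBOblique", "max_depth", 10 );    // hypothetical key
     *   conf.sD ( "DTBOblique", "lambda_init", 0.5 ); // hypothetical key
     *   DTBOblique dtb ( &conf, "DTBOblique" );
     */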

    /**
     * @brief initial building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param maxClassNo maximum class number
     * @return Pointer to root/parent node
     */
    DecisionNode *build ( const FeaturePool &fp,
                          const Examples &examples,
                          int maxClassNo );

};
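
/* A minimal usage sketch, assuming a prepared FeaturePool fp and Examples
 * from the surrounding vislearning framework (not a definitive training
 * pipeline):
 *
 *   DTBOblique builder ( &conf, "DTBOblique" );
 *   DecisionNode *root = builder.build ( fp, trainingExamples, maxClassNo );
 *   // root now points to the trained oblique decision tree
 */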

} // namespace OBJREC

#endif