// DTBOblique.h — NOTE(review): removed non-code extraction residue
// (file-size caption and a concatenated run of viewer line numbers).
  1. /**
  2. * @file DTBOblique.h
  3. * @brief oblique decision tree
  4. * @author Sven Sickert
  5. * @date 10/15/2014
  6. */
  7. #ifndef DTBOBLIQUEINCLUDE
  8. #define DTBOBLIQUEINCLUDE
  9. #include "core/vector/VectorT.h"
  10. #include "core/vector/MatrixT.h"
  11. #include "core/basics/Config.h"
  12. #include "DecisionTreeBuilder.h"
  13. #include "vislearning/cbaselib/CachedExample.h"
  14. namespace OBJREC {
/**
 * @brief Statistics describing one candidate split of a tree node.
 *
 * NOTE(review): distLeft/distRight are raw class-distribution arrays
 * (one entry per class number, judging by entropyLeftRight's
 * stat_left/stat_right usage). Ownership/lifetime is not expressed in
 * this header — presumably allocated during the split search and
 * released by the caller; confirm against the .cpp implementation.
 */
struct SplitInfo {
  double threshold;         //!< threshold on the projected feature value
  double informationGain;   //!< information gain achieved by this split
  double entropyLeft;       //!< entropy of the resulting left child
  double entropyRight;      //!< entropy of the resulting right child
  double *distLeft;         //!< class distribution of the left child (raw array, see note above)
  double *distRight;        //!< class distribution of the right child (raw array, see note above)
  NICE::Vector params;      //!< parameters of the oblique split (presumably the hyperplane weights "beta")
};
/** random oblique decision tree builder (splits on linear combinations
 *  of features instead of single feature dimensions) */
class DTBOblique : public DecisionTreeBuilder
{
  protected:

    /////////////////////////
    /////////////////////////
    // PROTECTED VARIABLES //
    /////////////////////////
    /////////////////////////

    /** Whether to use shannon entropy or not */
    bool useShannonEntropy;

    /** Whether to save indices in leaves or not */
    bool saveIndices;

    /** Whether to use one-vs-one or one-vs-all for multiclass scenarios */
    bool useOneVsOne;

    /** Amount of steps for complete search for best threshold */
    int splitSteps;

    /** Maximum allowed depth of a tree */
    int maxDepth;

    /** Minimum amount of features in a leaf node */
    int minExamples;

    /** Regularization type (selects the regularizer in regularizeDataMatrix) */
    int regularizationType;

    /** Minimum entropy to continue with splitting */
    double minimumEntropy;

    /** Minimum information gain to continue with splitting */
    double minimumInformationGain;

    /** Initial regularization parameter (buildRecursive threads a
     *  per-call curLambda through the recursion) */
    double lambdaInit;

    /////////////////////////
    /////////////////////////
    //  PROTECTED METHODS  //
    /////////////////////////
    /////////////////////////

    /**
     * @brief adapt data matrix and label vector to a binary problem for
     *        multi-class training (one-vs-one or one-vs-all)
     * @param posClass positive class number
     * @param negClass negative class number
     * @param matX adapted data matrix (in/out)
     * @param vecY adapted label vector (in/out)
     * @param posHasExamples whether positive class has examples or not (out)
     * @param negHasExamples whether negative class has examples or not (out)
     */
    void adaptDataAndLabelForMultiClass (
            const int posClass,
            const int negClass,
            NICE::Matrix & matX,
            NICE::Vector & vecY,
            bool & posHasExamples,
            bool & negHasExamples );

    /**
     * @brief get data matrix X and label vector y
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of selected example subset
     * @param matX data matrix (amountExamples x amountParameters) (out)
     * @param vecY label vector (amountExamples) (out)
     */
    void getDataAndLabel(
            const FeaturePool &fp,
            const Examples &examples,
            const std::vector<int> & examples_selection,
            NICE::Matrix &matX,
            NICE::Vector &vecY );

    /**
     * @brief compute a regularized X'*X of size (dimParams)x(dimParams)
     * @param X data matrix
     * @param XTXreg return regularized X'*X (out)
     * @param regOption which kind of regularization (cf. regularizationType)
     * @param lambda regularization parameter (weighting)
     */
    void regularizeDataMatrix (
            const NICE::Matrix & X,
            NICE::Matrix &XTXreg,
            const int regOption,
            const double lambda );

    /**
     * @brief find best threshold for current splitting
     * @param values feature values (in/out; taken by non-const reference)
     * @param bestSplitInfo struct including best split information (out)
     * @param beta parameters of the oblique projection the threshold is searched on
     * @param e entropy before split
     * @param maxClassNo maximum class number
     */
    void findBestSplitThreshold (
            FeatureValuesUnsorted & values,
            SplitInfo & bestSplitInfo,
            const NICE::Vector & beta,
            const double & e,
            const int & maxClassNo );

    /**
     * @brief recursive building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of selected example subset
     * @param distribution class distribution in current node
     * @param entropy current entropy
     * @param maxClassNo maximum class number
     * @param depth current depth
     * @param curLambda regularization parameter used at this node
     *        (presumably derived from lambdaInit — confirm in the .cpp)
     * @return Pointer to root/parent node
     */
    DecisionNode *buildRecursive (
            const FeaturePool & fp,
            const Examples & examples,
            std::vector<int> & examples_selection,
            FullVector & distribution,
            double entropy,
            int maxClassNo,
            int depth,
            double curLambda );

    /**
     * @brief compute entropy for left and right child
     * @param values feature values
     * @param threshold threshold for split
     * @param stat_left statistics for left child (out; raw array indexed by class)
     * @param stat_right statistics for right child (out; raw array indexed by class)
     * @param entropy_left entropy for left child (out)
     * @param entropy_right entropy for right child (out)
     * @param count_left amount of features in left child (out)
     * @param count_right amount of features in right child (out)
     * @param maxClassNo maximum class number
     * @return whether another split is possible or not
     */
    bool entropyLeftRight ( const FeatureValuesUnsorted & values,
            double threshold,
            double* stat_left,
            double* stat_right,
            double & entropy_left,
            double & entropy_right,
            double & count_left,
            double & count_right,
            int maxClassNo );

  public:

    /**
     * @brief simple constructor
     * @param conf configuration object the builder parameters are read from
     * @param section config section name to read from (default "DTBOblique")
     */
    DTBOblique ( const NICE::Config *conf,
                 std::string section = "DTBOblique" );

    /** simple destructor */
    virtual ~DTBOblique();

    /**
     * @brief initial building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param maxClassNo maximum class number
     * @return Pointer to root/parent node
     */
    DecisionNode *build ( const FeaturePool &fp,
                          const Examples &examples,
                          int maxClassNo );
};
  172. } //namespace
  173. #endif