/**
 * @file DTBObliqueLS.h
 * @brief oblique decision tree
 * @author Sven Sickert
 * @date 10/15/2014
 */
  7. #ifndef DTBOBLIQUELSINCLUDE
  8. #define DTBOBLIQUELSINCLUDE
  9. #include "core/vector/VectorT.h"
  10. #include "core/vector/MatrixT.h"
  11. #include "core/basics/Config.h"
  12. #include "DecisionTreeBuilder.h"
  13. #include "vislearning/cbaselib/CachedExample.h"
  14. namespace OBJREC {
/** Statistics describing one candidate split of a node. */
struct SplitInfo {
    /** threshold on the (projected) feature value */
    double threshold;
    /** information gain achieved by splitting at 'threshold' */
    double informationGain;
    /** entropy of the resulting left child */
    double entropyLeft;
    /** entropy of the resulting right child */
    double entropyRight;
    /** class distribution of the left child.
     *  NOTE(review): raw pointer — ownership and array size are not
     *  visible in this header (presumably maxClassNo+1 entries,
     *  allocated by the split search); confirm in DTBObliqueLS.cpp */
    double *distLeft;
    /** class distribution of the right child (see distLeft note) */
    double *distRight;
    /** parameters of the splitting function for this candidate */
    NICE::Vector params;
};
  24. /** random oblique decision tree */
/** random oblique decision tree */
class DTBObliqueLS : public DecisionTreeBuilder
{
  protected:

    /////////////////////////
    /////////////////////////
    // PROTECTED VARIABLES //
    /////////////////////////
    /////////////////////////

    /** Whether to use shannon entropy or not */
    bool useShannonEntropy;

    /** Whether to save indices in leaves or not */
    bool saveIndices;

    /** Whether to use one-vs-one or one-vs-all for multiclass scenarios */
    bool useOneVsOne;

    /** Whether to increase the influence of regularization over time or not */
    bool useDynamicRegularization;

    /** Amount of steps for complete search for best threshold */
    int splitSteps;

    /** Maximum allowed depth of a tree */
    int maxDepth;

    /** Minimum amount of features in a leaf node */
    int minExamples;

    /** Regularization type */
    int regularizationType;

    /** Minimum entropy to continue with splitting */
    double minimumEntropy;

    /** Minimum information gain to continue with splitting */
    double minimumInformationGain;

    /** Initial regularization parameter */
    double lambdaInit;

    /////////////////////////
    /////////////////////////
    //  PROTECTED METHODS  //
    /////////////////////////
    /////////////////////////

    /**
     * @brief adapt data matrix and label vector to the binary
     *        sub-problem of posClass vs. negClass
     * @param posClass positive class number
     * @param negClass negative class number
     * @param matX adapted data matrix (in/out)
     * @param vecY adapted label vector (in/out)
     * @return whether positive and negative classes have examples or not
     */
    bool adaptDataAndLabelForMultiClass (
        const int posClass,
        const int negClass,
        NICE::Matrix & matX,
        NICE::Vector & vecY );

    /**
     * @brief get data matrix X, label vector y and example weights w
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of selected example subset
     * @param X data matrix (amountExamples x amountParameters)
     * @param y label vector (amountExamples)
     * @param w example weight vector (amountExamples)
     */
    void getDataAndLabel(
        const FeaturePool &fp,
        const Examples &examples,
        const std::vector<int> & examples_selection,
        NICE::Matrix &X,
        NICE::Vector &y,
        NICE::Vector &w );

    /**
     * @brief return a regularization matrix of size (dimParams)x(dimParams)
     * @param X data matrix
     * @param XTXreg return regularized X'*X
     * @param regOption which kind of regularization
     * @param lambda regularization parameter (weighting)
     */
    void regularizeDataMatrix (
        const NICE::Matrix & X,
        NICE::Matrix &XTXreg,
        const int regOption,
        const double lambda );

    /**
     * @brief find best threshold for current splitting
     * @param values feature values
     * @param bestSplitInfo struct receiving the best split information
     * @param params parameters of the current split candidate
     * @param e entropy before split
     * @param maxClassNo maximum class number
     */
    void findBestSplitThreshold (
        FeatureValuesUnsorted & values,
        SplitInfo & bestSplitInfo,
        const NICE::Vector & params,
        const double & e,
        const int & maxClassNo );

    /**
     * @brief recursive building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of selected example subset
     * @param distribution class distribution in current node
     * @param entropy current entropy
     * @param maxClassNo maximum class number
     * @param depth current depth
     * @param curLambda current regularization parameter
     * @return Pointer to root/parent node
     */
    DecisionNode *buildRecursive (
        const FeaturePool & fp,
        const Examples & examples,
        std::vector<int> & examples_selection,
        FullVector & distribution,
        double entropy,
        int maxClassNo,
        int depth,
        double curLambda );

    /**
     * @brief compute entropy for left and right child
     * @param values feature values
     * @param threshold threshold for split
     * @param stat_left statistics for left child
     * @param stat_right statistics for right child
     * @param entropy_left entropy for left child
     * @param entropy_right entropy for right child
     * @param count_left amount of features in left child
     * @param count_right amount of features in right child
     * @param maxClassNo maximum class number
     * @return whether another split is possible or not
     */
    bool entropyLeftRight ( const FeatureValuesUnsorted & values,
                            double threshold,
                            double* stat_left,
                            double* stat_right,
                            double & entropy_left,
                            double & entropy_right,
                            double & count_left,
                            double & count_right,
                            int maxClassNo );

  public:

    /** simple constructor */
    DTBObliqueLS ( const NICE::Config *conf,
                   std::string section = "DTBObliqueLS" );

    /** simple destructor */
    virtual ~DTBObliqueLS();

    /**
     * @brief initial building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param maxClassNo maximum class number
     * @return Pointer to root/parent node
     */
    DecisionNode *build ( const FeaturePool &fp,
                          const Examples &examples,
                          int maxClassNo );
};
  173. } //namespace
  174. #endif