  1. /**
  2. * @file DTBObliqueLS.h
  3. * @brief oblique decision tree
  4. * @author Sven Sickert
  5. * @date 10/15/2014
  6. */
  7. #ifndef DTBOBLIQUELSINCLUDE
  8. #define DTBOBLIQUELSINCLUDE
  9. #include "core/vector/VectorT.h"
  10. #include "core/vector/MatrixT.h"
  11. #include "core/basics/Config.h"
  12. #include "DecisionTreeBuilder.h"
  13. #include "SplittingCriterion.h"
  14. #include "vislearning/cbaselib/CachedExample.h"
  15. namespace OBJREC {
/** Bundle describing one (candidate or best) split of a tree node. */
struct SplitInfo {
    double threshold;      //!< feature threshold of the split decision
    double purity;         //!< purity score of the split
    double entropy;        //!< entropy of the split
    double *distLeft;      //!< class distribution of the left branch (raw array; NOTE(review): size and ownership are not visible in this header — confirm who allocates/frees in the .cpp)
    double *distRight;     //!< class distribution of the right branch (same ownership caveat as distLeft)
    NICE::Vector params;   //!< parameter vector of the oblique decision (hyperplane)
};
/** Random oblique decision tree builder. Splits are oblique (linear)
 *  decisions; judging by the class name and the regularization members
 *  below they are presumably fitted via regularized least squares (LS)
 *  — confirm against the implementation in the .cpp. */
class DTBObliqueLS : public DecisionTreeBuilder
{
  protected:

    /////////////////////////
    /////////////////////////
    // PROTECTED VARIABLES //
    /////////////////////////
    /////////////////////////

    /** Splitting criterion used to rate candidate splits
     *  (NOTE(review): ownership/deletion not visible here — confirm in the .cpp) */
    SplittingCriterion *splitCriterion;

    /** Whether to save example indices in the leaves or not */
    bool saveIndices;

    /** Whether to use one-vs-one (0), one-vs-all (1) or many-vs-many (2) for multiclass scenarios */
    int multiClassMode;

    /** Whether to increase the influence of regularization over time or not */
    bool useDynamicRegularization;

    /** Amount of steps for complete search for best threshold */
    int splitSteps;

    /** Maximum allowed depth of a tree */
    int maxDepth;

    /** Regularization type (interpreted by regularizeDataMatrix's regOption) */
    int regularizationType;

    /** Initial regularization parameter (lambda) */
    double lambdaInit;

    /////////////////////////
    /////////////////////////
    //  PROTECTED METHODS  //
    /////////////////////////
    /////////////////////////

    /**
     * @brief adapt data and labels to a binary problem for the given class pair
     * @param posClass positive class number
     * @param negClass negative class number
     * @param matX adapted data matrix (in/out)
     * @param vecY adapted label vector (in/out)
     * @return whether positive and negative classes have examples or not
     */
    bool adaptDataAndLabelForMultiClass (
        const int posClass,
        const int negClass,
        NICE::Matrix & matX,
        NICE::Vector & vecY );

    /**
     * @brief get data matrix X, label vector y and weight vector w
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of selected example subset
     * @param X data matrix (amountExamples x amountParameters)
     * @param y label vector (amountExamples)
     * @param w example weight vector (amountExamples)
     */
    void getDataAndLabel(
        const FeaturePool &fp,
        const Examples &examples,
        const std::vector<int> & examples_selection,
        NICE::Matrix &X,
        NICE::Vector &y,
        NICE::Vector &w );

    /**
     * @brief compute a regularized X'*X of size (dimParams)x(dimParams)
     * @param X data matrix
     * @param XTXreg return regularized X'*X
     * @param regOption which kind of regularization
     * @param lambda regularization parameter (weighting)
     */
    void regularizeDataMatrix (
        const NICE::Matrix & X,
        NICE::Matrix &XTXreg,
        const int regOption,
        const double lambda );

    /**
     * @brief find best threshold for current splitting
     * @param values feature values (unsorted)
     * @param bestSplitInfo struct receiving the best split information (in/out)
     * @param params parameter vector for oblique decision
     * @param maxClassNo maximum class number
     */
    void findBestSplitThreshold (
        FeatureValuesUnsorted & values,
        SplitInfo & bestSplitInfo,
        const NICE::Vector & params,
        const int & maxClassNo );

    /**
     * @brief recursive building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param examples_selection indices of selected example subset
     * @param distribution class distribution in current node
     * @param entropy current entropy
     * @param maxClassNo maximum class number
     * @param depth current depth
     * @param curLambda current regularization parameter (presumably derived from
     *        lambdaInit and adapted per level when useDynamicRegularization is
     *        set — confirm in the .cpp)
     * @return Pointer to root/parent node
     */
    DecisionNode *buildRecursive (
        const FeaturePool & fp,
        const Examples & examples,
        std::vector<int> & examples_selection,
        FullVector & distribution,
        double entropy,
        int maxClassNo,
        int depth,
        double curLambda );

  public:

    /**
     * @brief simple constructor
     * @param conf configuration to read the builder's parameters from
     * @param section config section name holding those parameters
     */
    DTBObliqueLS ( const NICE::Config *conf,
                   std::string section = "DTBObliqueLS" );

    /** simple destructor */
    virtual ~DTBObliqueLS();

    /**
     * @brief initial building method
     * @param fp feature pool
     * @param examples all examples of the training
     * @param maxClassNo maximum class number
     * @return Pointer to root/parent node
     */
    DecisionNode *build ( const FeaturePool &fp,
                          const Examples &examples,
                          int maxClassNo );
};
  144. } //namespace
  145. #endif