DTBOblique.h — oblique decision tree builder header (4.6 KB)
  1. /**
  2. * @file DTBOblique.h
  3. * @brief oblique decision tree
  4. * @author Sven Sickert
  5. * @date 10/15/2014
  6. */
  7. #ifndef DTBOBLIQUEINCLUDE
  8. #define DTBOBLIQUEINCLUDE
  9. #include "core/vector/VectorT.h"
  10. #include "core/vector/MatrixT.h"
  11. #include "core/basics/Config.h"
  12. #include "DecisionTreeBuilder.h"
  13. #include "vislearning/cbaselib/CachedExample.h"
  14. namespace OBJREC {
  15. /** random oblique decision tree */
  16. class DTBOblique : public DecisionTreeBuilder
  17. {
  18. protected:
  19. /////////////////////////
  20. /////////////////////////
  21. // PROTECTED VARIABLES //
  22. /////////////////////////
  23. /////////////////////////
  24. /** Amount of steps for complete search for best threshold */
  25. int split_steps;
  26. /** Maximum allowed depth of a tree */
  27. int max_depth;
  28. /* Minimum amount of features in a leaf node */
  29. int min_examples;
  30. /** Minimum entropy to continue with splitting */
  31. int minimum_entropy;
  32. /** Minimum information gain to continue with splitting */
  33. int minimum_information_gain;
  34. /** Whether to use shannon entropy or not */
  35. bool use_shannon_entropy;
  36. /** Whether to save indices in leaves or not */
  37. bool save_indices;
  38. /** Regularization parameter */
  39. double lambdaInit;
  40. /** Regularization type */
  41. int regularizationType;
  42. /////////////////////////
  43. /////////////////////////
  44. // PROTECTED METHODS //
  45. /////////////////////////
  46. /////////////////////////
  47. /**
  48. * @brief get data matrix X and label vector y
  49. * @param fp feature pool
  50. * @param examples all examples of the training
  51. * @param examples_selection indeces of selected example subset
  52. * @param matX data matrix (amountExamples x amountParameters)
  53. * @param vecY label vector (amountExamples)
  54. */
  55. void getDataAndLabel(
  56. const FeaturePool &fp,
  57. const Examples &examples,
  58. const std::vector<int> & examples_selection,
  59. NICE::Matrix &matX,
  60. NICE::Vector &vecY );
  61. /**
  62. * @brief return a regularization matrix of size (dimParams)x(dimParams)
  63. * @param X data matrix
  64. * @param XTXreg return regularized X'*X
  65. * @param regOption which kind of regularization
  66. * @param lambda regularization parameter (weigthing)
  67. */
  68. void regularizeDataMatrix (
  69. const NICE::Matrix & X,
  70. NICE::Matrix &XTXreg,
  71. const int regOption,
  72. const double lambda );
  73. /**
  74. * @brief recursive building method
  75. * @param fp feature pool
  76. * @param examples all examples of the training
  77. * @param examples_selection indeces of selected example subset
  78. * @param distribution class distribution in current node
  79. * @param entropy current entropy
  80. * @param maxClassNo maximum class number
  81. * @param depth current depth
  82. * @return Pointer to root/parent node
  83. */
  84. DecisionNode *buildRecursive (
  85. const FeaturePool & fp,
  86. const Examples & examples,
  87. std::vector<int> & examples_selection,
  88. FullVector & distribution,
  89. double entropy,
  90. int maxClassNo,
  91. int depth,
  92. double curLambda );
  93. /**
  94. * @brief compute entropy for left and right child
  95. * @param values feature values
  96. * @param threshold threshold for split
  97. * @param stat_left statistics for left child
  98. * @param stat_right statistics for right child
  99. * @param entropy_left entropy for left child
  100. * @param entropy_right entropy for right child
  101. * @param count_left amount of features in left child
  102. * @param count_right amount of features in right child
  103. * @param maxClassNo maximum class number
  104. * @return whether another split is possible or not
  105. */
  106. bool entropyLeftRight ( const FeatureValuesUnsorted & values,
  107. double threshold,
  108. double* stat_left,
  109. double* stat_right,
  110. double & entropy_left,
  111. double & entropy_right,
  112. double & count_left,
  113. double & count_right,
  114. int maxClassNo );
  115. public:
  116. /** simple constructor */
  117. DTBOblique ( const NICE::Config *conf,
  118. std::string section = "DTBOblique" );
  119. /** simple destructor */
  120. virtual ~DTBOblique();
  121. /**
  122. * @brief initial building method
  123. * @param fp feature pool
  124. * @param examples all examples of the training
  125. * @param maxClassNo maximum class number
  126. * @return Pointer to root/parent node
  127. */
  128. DecisionNode *build ( const FeaturePool &fp,
  129. const Examples &examples,
  130. int maxClassNo );
  131. };
  132. } //namespace
  133. #endif