/** * @file DTBObliqueLS.h * @brief oblique decision tree * @author Sven Sickert * @date 10/15/2014 */ #ifndef DTBOBLIQUELSINCLUDE #define DTBOBLIQUELSINCLUDE #include "core/vector/VectorT.h" #include "core/vector/MatrixT.h" #include "core/basics/Config.h" #include "DecisionTreeBuilder.h" #include "SplittingCriterion.h" #include "vislearning/cbaselib/CachedExample.h" namespace OBJREC { struct SplitInfo { double threshold; double purity; double entropy; double *distLeft; double *distRight; NICE::Vector params; }; /** random oblique decision tree */ class DTBObliqueLS : public DecisionTreeBuilder { protected: ///////////////////////// ///////////////////////// // PROTECTED VARIABLES // ///////////////////////// ///////////////////////// /** Splitting criterion */ SplittingCriterion *splitCriterion; /** Whether to save indices in leaves or not */ bool saveIndices; /** Whether to use one-vs-one (0), one-vs-all (1) or many-vs-many (2) for multiclass scenarios */ int multiClassMode; /** Whether to increase the influence of regularization over time or not */ bool useDynamicRegularization; /** Amount of steps for complete search for best threshold */ int splitSteps; /** Maximum allowed depth of a tree */ int maxDepth; /** Regularization type */ int regularizationType; /** Regularization parameter */ double lambdaInit; ///////////////////////// ///////////////////////// // PROTECTED METHODS // ///////////////////////// ///////////////////////// /** * @brief adaptDataAndLabelForMultiClass * @param posClass positive class number * @param negClass negative class number * @param matX adapted data matrix * @param vecY adapted label vector * @param weights example weights * @return whether positive and negative classes have examples or not */ bool adaptDataAndLabelForMultiClass ( const int posClass, const int negClass, NICE::Matrix & matX, NICE::Vector & vecY ); /** * @brief get data matrix X and label vector y * @param fp feature pool * @param examples all examples of the training * @param examples_selection indeces of selected example subset * @param matX data matrix (amountExamples x amountParameters) * @param vecY label vector (amountExamples) */ void getDataAndLabel( const FeaturePool &fp, const Examples &examples, const std::vector & examples_selection, NICE::Matrix &X, NICE::Vector &y, NICE::Vector &w ); /** * @brief return a regularization matrix of size (dimParams)x(dimParams) * @param X data matrix * @param XTXreg return regularized X'*X * @param regOption which kind of regularization * @param lambda regularization parameter (weigthing) */ void regularizeDataMatrix ( const NICE::Matrix & X, NICE::Matrix &XTXreg, const int regOption, const double lambda ); /** * @brief find best threshold for current splitting * @param values feature values * @param bestSplitInfo struct including best split information * @param params parameter vector for oblique decision * @param maxClassNo maximum class number */ void findBestSplitThreshold ( FeatureValuesUnsorted & values, SplitInfo & bestSplitInfo, const NICE::Vector & params, const int & maxClassNo ); /** * @brief recursive building method * @param fp feature pool * @param examples all examples of the training * @param examples_selection indeces of selected example subset * @param distribution class distribution in current node * @param entropy current entropy * @param maxClassNo maximum class number * @param depth current depth * @return Pointer to root/parent node */ DecisionNode *buildRecursive ( const FeaturePool & fp, const Examples & examples, std::vector & examples_selection, FullVector & distribution, double entropy, int maxClassNo, int depth, double curLambda ); public: /** simple constructor */ DTBObliqueLS ( const NICE::Config *conf, std::string section = "DTBObliqueLS" ); /** simple destructor */ virtual ~DTBObliqueLS(); /** * @brief initial building method * @param fp feature pool * @param examples all examples of the training * @param maxClassNo maximum class number * @return Pointer to root/parent node */ DecisionNode *build ( const FeaturePool &fp, const Examples &examples, int maxClassNo ); }; } //namespace #endif