RTBLinear.h 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
/**
* @file RTBLinear.h
* @brief Random regression tree that learns an LSE (least-squares error) model in every inner node during training.
* @author Frank Prüfer
* @date 09/17/2013
*/
  7. #ifndef RTBLINEARINCLUDE
  8. #define RTBLINEARINCLUDE
  9. #include <vector>
  10. #include "core/vector/VectorT.h"
  11. #include "core/vector/VVector.h"
  12. #include "core/basics/Config.h"
  13. #include "RegressionTreeBuilder.h"
  14. namespace OBJREC {
  15. /** random regression tree */
  16. class RTBLinear : public RegressionTreeBuilder
  17. {
  18. protected:
  19. int random_split_tests;
  20. int random_features;
  21. int max_depth;
  22. int min_examples;
  23. double minimum_error_reduction;
  24. int random_split_mode;
  25. /** save indices in leaves */
  26. bool save_indices;
  27. enum {
  28. RANDOM_SPLIT_INDEX = 0,
  29. RANDOM_SPLIT_UNIFORM
  30. };
  31. RegressionNode *buildRecursive ( const NICE::VVector & x,
  32. const NICE::Vector & y,
  33. std::vector<int> & selection,
  34. int depth);
  35. void computeLinearLSError ( const NICE::VVector & x,
  36. const NICE::Vector & y,
  37. const int & numEx,
  38. double & lsError);
  39. bool errorReductionLeftRight ( const std::vector< std::pair< double, int > > values,
  40. const NICE::Vector & y,
  41. double threshold,
  42. double & error_left,
  43. double & error_right,
  44. int & count_left,
  45. int & count_right );
  46. public:
  47. /** simple constructor */
  48. RTBLinear( const NICE::Config *conf, std::string section = "RTBLinear" );
  49. /** simple destructor */
  50. virtual ~RTBLinear();
  51. RegressionNode *build ( const NICE::VVector & x,
  52. const NICE::Vector & y );
  53. };
  54. } // namespace
  55. #endif