RegPreRandomForests.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. /**
  2. * @file RegPreRandomForests.cpp
  3. * @brief Combination of a regression method with a pre-clustering using a random forest
  4. * @author Sven Sickert
  5. * @date 07/12/2013
  6. */
  7. #include "vislearning/regression/regcombination/RegPreRandomForests.h"
  8. #include <iostream>
  9. #include <assert.h>
  10. using namespace OBJREC;
  11. using namespace std;
  12. using namespace NICE;
  13. RegPreRandomForests::RegPreRandomForests(const Config * conf,
  14. const string & section,
  15. RegressionAlgorithm *_leafRegressionPrototype )
  16. : leafRegressionPrototype(_leafRegressionPrototype)
  17. {
  18. string cluster_section = conf->gS ( section, "cluster_section", "RandomForest" );
  19. mEx = conf->gI ( "RTBRandom", "min_examples", 500 );
  20. randomforest = new RegRandomForests( conf, cluster_section );
  21. }
  22. RegPreRandomForests::~RegPreRandomForests()
  23. {
  24. // delete the random forest
  25. if ( randomforest != NULL )
  26. delete randomforest;
  27. // delte all regression methods in the leafs
  28. for ( map<RegressionNode *, RegressionAlgorithm * >::const_iterator it = leafRegressions.begin();
  29. it != leafRegressions.end(); it++ )
  30. {
  31. RegressionAlgorithm * lr = it->second;
  32. delete lr;
  33. }
  34. }
  35. void RegPreRandomForests::teach ( const VVector & X, const Vector & y )
  36. {
  37. randomforest->teach ( X, y );
  38. if ( leafRegressionPrototype != NULL )
  39. {
  40. vector<RegressionNode *> leafNodes;
  41. randomforest->getAllLeafNodes ( leafNodes );
  42. int lsize = leafNodes.size();
  43. int leafNo = 0;
  44. cerr << "leafnodes: " << lsize << endl;
  45. #pragma omp parallel for
  46. for ( int l = 0; l < lsize; l++ )
  47. {
  48. leafNo++;
  49. RegressionNode *node = leafNodes[l];
  50. if ( !node->isLeaf() ){
  51. fprintf( stderr, "RegPreRandomForests::predict: ID #%d not a leaf node!", leafNo );
  52. continue;
  53. }
  54. vector<int> leafTrainInds = node->trainExamplesIndices;
  55. cerr << "Teaching regression method for leaf " << leafNo-1 << "..." << endl;
  56. cerr << "examples in leave: " << leafTrainInds.size() << endl;
  57. assert ( leafTrainInds.size() > 0 );
  58. sort ( leafTrainInds.begin(), leafTrainInds.end() );
  59. NICE::VVector leafTrainData;
  60. vector<double> tmpVals;
  61. for ( int i = 0; i < (int)leafTrainInds.size(); i++ )
  62. {
  63. if ( leafTrainInds[i] >= 0 && leafTrainInds[i] < (int)y.size() )
  64. {
  65. leafTrainData.push_back( X[ leafTrainInds[i] ] );
  66. tmpVals.push_back( y[ leafTrainInds[i] ] );
  67. }
  68. }
  69. if (leafTrainData.size() <= 0 ) continue;
  70. NICE::Vector leafTrainVals( tmpVals );
  71. RegressionAlgorithm *lr = leafRegressionPrototype->clone();
  72. lr->teach( leafTrainData, leafTrainVals );
  73. leafRegressions.insert ( pair< RegressionNode *, RegressionAlgorithm *> ( node, lr ) );
  74. }
  75. }
  76. }
  77. double RegPreRandomForests::predict ( const Vector & x )
  78. {
  79. double pred = 0.0;
  80. vector<RegressionNode *> leafNodes;
  81. // traverse the forest and obtain all innvolved leaf nodes
  82. randomforest->getLeafNodes ( x, leafNodes );
  83. for ( vector<RegressionNode *>::const_iterator it = leafNodes.begin();
  84. it != leafNodes.end(); it++ )
  85. {
  86. RegressionNode *node = *it;
  87. map<RegressionNode *, RegressionAlgorithm *>::const_iterator leafRegressionIt =
  88. leafRegressions.find( node );
  89. if ( leafRegressionIt == leafRegressions.end() )
  90. {
  91. // this leaf has no associated regression method
  92. // -> we will use the random forest result
  93. pred += node->predVal;
  94. continue;
  95. }
  96. RegressionAlgorithm *leafRegression = leafRegressionIt->second;
  97. pred += leafRegression->predict( x );
  98. }
  99. pred /= leafNodes.size();
  100. return pred;
  101. }
  102. void RegPreRandomForests::clear ()
  103. {
  104. map<RegressionNode *, RegressionAlgorithm *>::iterator iter;
  105. for ( iter = leafRegressions.begin(); iter != leafRegressions.end(); iter++ )
  106. {
  107. iter->second->clear();
  108. }
  109. randomforest->clear();
  110. }
  111. void RegPreRandomForests::store ( ostream & os, int format ) const
  112. {
  113. cerr << "RegPreRandomForest::store: not yet implemented" << endl;
  114. }
  115. void RegPreRandomForests::restore ( istream& is, int format )
  116. {
  117. cerr << "RegPreRandomForest::restore: not yet implemented" << endl;
  118. }