/** * @file RegressionTree.cpp * @brief regression tree implementation * @author Sven Sickert * @date 06/19/2013 */ #include #include #include "vislearning/regression/randomforest/RegressionTree.h" using namespace OBJREC; using namespace std; using namespace NICE; RegressionTree::RegressionTree( const Config *_conf ) : conf(_conf) { root = NULL; } RegressionTree::~RegressionTree() { deleteNodes ( root ); } void RegressionTree::statistics ( int & depth, int & count ) const { if ( root == NULL ) { depth = 0; count = 0; } else { root->statistics ( depth, count ); } } void RegressionTree::traverse ( const Vector & x, double & predVal ) { assert( root != NULL ); root->traverse ( x, predVal ); } void RegressionTree::deleteNodes ( RegressionNode *tree ) { if ( tree != NULL ) { deleteNodes ( tree->left ); deleteNodes ( tree->right ); delete tree; } } void RegressionTree::clear () { deleteNodes ( root ); } void RegressionTree::resetCounters () { if ( root != NULL ) root->resetCounters (); } void RegressionTree::indexDescendants ( map > & index, long & maxindex ) const { if ( root != NULL ) root->indexDescendants ( index, maxindex, 0 ); } RegressionNode *RegressionTree::getLeafNode ( Vector & x, int maxdepth ) { return root->getLeafNode ( x, maxdepth ); } void RegressionTree::getLeaves( RegressionNode *node, vector &leaves) { if(node->left == NULL && node->right == NULL) { leaves.push_back(node); return; } getLeaves(node->right, leaves); getLeaves(node->left, leaves); } vector RegressionTree::getAllLeafNodes() { vector leaves; getLeaves(root, leaves); return leaves; } void RegressionTree::setRoot ( RegressionNode *newroot ) { root = newroot; } RegressionNode *RegressionTree::pruneTreeLeastSquares ( RegressionNode *node, double minErrorReduction, double & lsError ) { if ( node == NULL ) return NULL; lsError = node->lsError; double leftError, rightError; node->left = pruneTreeLeastSquares ( node->left, minErrorReduction, leftError ); node->right = pruneTreeLeastSquares ( node->right, minErrorReduction, rightError ); if (node->left != NULL && node->right != NULL) { if (lsError-leftError-rightError < minErrorReduction) { deleteNodes( node->left ); deleteNodes( node->right ); } } return node; } void RegressionTree::pruneTreeLeastSquares ( double minErrorReduction ) { int depth, count; statistics ( depth, count ); fprintf (stderr, "RegressionTree::pruneTreeLeastSquares: depth %d count %d\n", depth, count ); double tmp; root = pruneTreeLeastSquares ( root, minErrorReduction, tmp ); statistics ( depth, count ); fprintf (stderr, "RegressionTree::pruneTreeLeastSquares: depth %d count %d (modified)\n", depth, count ); } void RegressionTree::store (ostream & os, int format) const { if ( root == NULL ) return; // indexing map > index; index.insert ( pair > ( NULL, pair ( 0, 0 ) ) ); index.insert ( pair > ( root, pair ( 1, 0 ) ) ); long maxindex = 1; root->indexDescendants ( index, maxindex, 0 ); for ( map >::iterator i = index.begin(); i != index.end(); i++ ) { RegressionNode *node = i->first; if ( node == NULL ) continue; long ind = i->second.first; long ind_l = index[ node->left ].first; long ind_r = index[ node->right ].first; os << "NODE " << ind << " " << ind_l << " " << ind_r << endl; if ( !node->isLeaf() ) { os << node->f; os << endl; os << node->threshold; os << endl; } else { os << "LEAF"; os << endl; } os << node->lsError << " " << -1 << endl; } } void RegressionTree::restore (istream & is, int format) { // indexing map index; map > descendants; index.insert ( pair ( 0, NULL ) ); // refactor-nice.pl: check this substitution // old: string tag; std::string tag; while ( (! is.eof()) && ( (is >> tag) && (tag == "NODE") ) ) { long ind; long ind_l; long ind_r; if (! (is >> ind)) break; if (! (is >> ind_l)) break; if (! (is >> ind_r)) break; descendants.insert ( pair > ( ind, pair ( ind_l, ind_r ) ) ); RegressionNode *node = new RegressionNode(); index.insert ( pair ( ind, node ) ); std::string feature_tag; is >> feature_tag; if ( feature_tag != "LEAF" ) { is >> node->f; is >> node->threshold; } is >> node->lsError; } // connecting the tree for ( map::const_iterator it = index.begin(); it != index.end(); it++ ) { RegressionNode *node = it->second; if ( node == NULL ) continue; long ind_l = descendants[it->first].first; long ind_r = descendants[it->first].second; map::const_iterator il = index.find ( ind_l ); map::const_iterator ir = index.find ( ind_r ); if ( ( il == index.end() ) || ( ir == index.end() ) ) { fprintf (stderr, "File inconsistent: unable to build tree\n"); exit(-1); } RegressionNode *left = il->second; RegressionNode *right = ir->second; node->left = left; node->right = right; } map::const_iterator iroot = index.find ( 1 ); if ( iroot == index.end() ) { fprintf (stderr, "File inconsistent: unable to build tree (root node not found)\n"); exit(-1); } root = iroot->second; }