Эх сурвалжийг харах

fixes to CodebookRandomForest: added store and restore, etc

Johannes Ruehle 11 жил өмнө
parent
commit
459db1c0a8

+ 142 - 16
features/simplefeatures/CodebookRandomForest.cpp

@@ -16,6 +16,8 @@ using namespace OBJREC;
 using namespace std;
 using namespace NICE;
 
+#undef DEBUGPRUNING
+
 CodebookRandomForest::CodebookRandomForest( int maxDepth, int restrictedCodebookSize )
 {
 	this->clusterforest = NULL;
@@ -109,8 +111,10 @@ void CodebookRandomForest::pruneForest ()
     while ( (leafs > restrictedCodebookSize) && (lastLevelInnerNodes.size() > 0) )
     {
 		const triplet<double, long, DecisionNode *> & nodemi = lastLevelInnerNodes.top();
-		//double current_mi = -nodemi.first;
-		//fprintf (stderr, "CodebookRandomForest: %d contract leaf with mutual information %f\n", leafs, current_mi );
+        #ifdef DEBUGPRUNING
+            double current_mi = -nodemi.first;
+            fprintf (stderr, "CodebookRandomForest: %d contract leaf with mutual information %f\n", leafs, current_mi );
+        #endif
 		DecisionNode *node = nodemi.third;
 		lastLevelInnerNodes.pop();
 
@@ -412,7 +416,40 @@ void CodebookRandomForest::voteAndClassify ( const NICE::Vector & feature, NICE:
 			distribution.add ( sDistribution );
     }
 
-	distribution.normalize();
+    distribution.normalize();
+}
+
+void CodebookRandomForest::voteAndClassify(const Vector &feature, SparseVector &votes, Vector &distribution) const
+{
+    vector<DecisionNode *> leafNodes;
+    NICE::Vector *x = new NICE::Vector ( feature );
+    Example pe ( x );
+    clusterforest->getLeafNodes ( pe, leafNodes, maxDepth );
+    delete x;
+
+    for ( vector<DecisionNode *>::const_iterator j = leafNodes.begin();
+        j != leafNodes.end(); j++ )
+    {
+        map<DecisionNode *, int>::const_iterator k = leafMap.find ( *j );
+        DecisionNode *node = *j;
+
+        assert ( k != leafMap.end() );
+        int leafindex = k->second;
+        votes.insert ( votes.begin(), pair<int, double> ( leafindex, 1.0 ) );
+
+        FullVector sDistribution ( node->distribution );
+        sDistribution.normalize();
+        if ( distribution.size() == 0 )
+        {
+            distribution.resize(sDistribution.size() );
+            distribution.set(0.0f);
+        }
+        for(int i = 0; i< sDistribution.size(); i++)
+            distribution[i] += sDistribution[i];
+
+    }
+
+    distribution.normalizeL2();
 }
 
 void CodebookRandomForest::add ( const Codebook *codebook )
@@ -434,21 +471,110 @@ void CodebookRandomForest::clear ()
 
 void CodebookRandomForest::restore ( istream & is, int format )
 {
-	if(this->clusterforest == NULL)
-		this->clusterforest = new FPCRandomForests ();
-    Codebook::restore(is, format);
-    int maxClassNo = 0;
-    is >> maxClassNo;
-    clusterforest->setMaxClassNo( maxClassNo );
-    clusterforest->restore ( is, format );
-    buildLeafMap();
+    if (is.good())
+    {
+        std::string tmp;
+        is >> tmp; //class name
+
+        if ( ! this->isStartTag( tmp, "CodebookRandomForest" ) )
+        {
+            std::cerr << " WARNING - attempt to restore CodebookRandomForest, but start flag " << tmp << " does not match! Aborting... " << std::endl;
+            throw;
+        }
+
+        if(this->clusterforest == NULL)
+            this->clusterforest = new FPCRandomForests ();
+
+
+        bool b_endOfBlock = false;
+
+        while ( !b_endOfBlock )
+        {
+            is >> tmp; // start of block
+
+            if ( this->isEndTag( tmp, "CodebookRandomForest" )  || is.eof() )
+            {
+                b_endOfBlock = true;
+                continue;
+            }
+
+            tmp = this->removeStartTag ( tmp );
+            if ( tmp.compare("baseclass") == 0 )
+            {
+                Codebook::restore(is, format);
+                is >> tmp; // end of block
+                tmp = this->removeEndTag ( tmp );
+            }
+            else if ( tmp.compare("maxDepth") == 0 )
+            {
+                is >> maxDepth;
+                is >> tmp; // end of block
+                tmp = this->removeEndTag ( tmp );
+            }
+            else if ( tmp.compare("restrictedCodebookSize") == 0 )
+            {
+                is >> restrictedCodebookSize;
+                is >> tmp; // end of block
+                tmp = this->removeEndTag ( tmp );
+            }
+            else if ( tmp.compare("maxClassNo") == 0 )
+            {
+                int maxClassNo = 0;
+                is >> maxClassNo;
+                is >> tmp; // end of block
+                tmp = this->removeEndTag ( tmp );
+
+                if(clusterforest != NULL)
+                    clusterforest->setMaxClassNo(maxClassNo);
+            }
+            else if ( tmp.compare("clusterforest") == 0 )
+            {
+                clusterforest->restore ( is, format );
+                is >> tmp; // end of block
+                tmp = this->removeEndTag ( tmp );
+            }
+        }
+
+        buildLeafMap();
+    }
 }
 
 void CodebookRandomForest::store ( ostream & os, int format ) const
 {
-    Codebook::store ( os, format );
-    os << endl;
-    os << clusterforest->getMaxClassNo() << endl;
-    clusterforest->store ( os, format );
-    os << endl;
+    if (os.good())
+    {
+        // show starting point
+        os << this->createStartTag( "CodebookRandomForest" ) << std::endl;
+
+        os.precision (numeric_limits<double>::digits10 + 1);
+
+        os << this->createStartTag( "baseclass" ) << std::endl;
+        Codebook::store ( os, format );
+        os << this->createEndTag( "baseclass" ) << std::endl;
+
+        os << this->createStartTag( "maxDepth" ) << std::endl;
+        os << maxDepth << std::endl;
+        os << this->createEndTag( "maxDepth" ) << std::endl;
+
+        os << this->createStartTag( "restrictedCodebookSize" ) << std::endl;
+        os << restrictedCodebookSize << std::endl;
+        os << this->createEndTag( "restrictedCodebookSize" ) << std::endl;
+
+        os << this->createStartTag( "maxClassNo" ) << std::endl;
+        os << clusterforest->getMaxClassNo() << endl;
+        os << this->createEndTag( "maxClassNo" ) << std::endl;
+
+        os << this->createStartTag( "clusterforest" ) << std::endl;
+        clusterforest->store ( os, format );
+        os << this->createEndTag( "clusterforest" ) << std::endl;
+/*        Codebook::store ( os, format );
+        os << maxDepth << endl;
+        os << restrictedCodebookSize << endl;
+        os << clusterforest->getMaxClassNo() << endl;
+        clusterforest->store ( os, format );
+        os << endl;
+*/
+        // done
+        os << this->createEndTag( "CodebookRandomForest" ) << std::endl;
+    }
 }

+ 12 - 0
features/simplefeatures/CodebookRandomForest.h

@@ -91,6 +91,8 @@ class CodebookRandomForest : public Codebook
 
 		/** normal codebook voting, but additionally returns a probability distribution for the class label **/
         void voteAndClassify ( const NICE::Vector & feature, NICE::SparseVector & votes, FullVector & distribution ) const;
+        /** normal codebook voting, but additionally returns a probability distribution for the class label **/
+        void voteAndClassify ( const NICE::Vector & feature, NICE::SparseVector & votes, NICE::Vector & distribution ) const;
         virtual void voteVA ( const NICE::Vector & feature, NICE::SparseVector & votes ) const {
             this->vote(feature, votes);
         }
@@ -114,6 +116,16 @@ class CodebookRandomForest : public Codebook
 
 		/** write the codebook to a stream */
 		void store ( std::ostream & os, int format = 0) const;
+
+        int getMaxDepth() const
+        {
+            return this->maxDepth;
+        }
+
+        int getRestrictedCodebookSize() const
+        {
+            return restrictedCodebookSize;
+        }
 };