Bläddra i källkod

random pos/neg class selection

Sven Sickert 10 år sedan
förälder
incheckning
10dc1d5294
1 ändrade filer med 39 tillägg och 32 borttagningar
  1. 39 32
      classifier/fpclassifier/randomforest/DTBObliqueLS.cpp

+ 39 - 32
classifier/fpclassifier/randomforest/DTBObliqueLS.cpp

@@ -102,6 +102,8 @@ bool DTBObliqueLS::adaptDataAndLabelForMultiClass (
 {
     bool posHasExamples = false;
     bool negHasExamples = false;
+    int posCount = 0;
+    int negCount = 0;
 
     // One-vs-one: Transforming into {-1,0,+1} problem
     if ( useOneVsOne )
@@ -111,11 +113,13 @@ bool DTBObliqueLS::adaptDataAndLabelForMultiClass (
             {
                 y[i] = 1.0;
                 posHasExamples = true;
+                posCount++;
             }
             else if ( y[i] == negClass )
             {
                 y[i] = -1.0;
                 negHasExamples = true;
+                negCount++;
             }
             else
             {
@@ -131,11 +135,13 @@ bool DTBObliqueLS::adaptDataAndLabelForMultiClass (
             {
                 y[i] = 1.0;
                 posHasExamples = true;
+                posCount++;
             }
             else
             {
                 y[i] = -1.0;
                 negHasExamples = true;
+                negCount++;
             }
         }
 
@@ -383,50 +389,51 @@ DecisionNode *DTBObliqueLS::buildRecursive(
     getDataAndLabel( fp, examples, examples_selection, X, y, weights );
 
     // Transforming into multi-class problem
-    for ( int posClass = 0; posClass <= maxClassNo; posClass++ )
+    bool hasExamples = false;
+    NICE::Vector yCur;
+    NICE::Matrix XCur;
+
+    while ( !hasExamples )
     {
-        bool gotInnerIteration = false;
-        for ( int negClass = 0; negClass <= maxClassNo; negClass++ )
-        {
-            if ( posClass == negClass ) continue;
+        int posClass, negClass;
 
-            NICE::Vector yCur = y;
-            NICE::Matrix XCur = X;
+        posClass = rand() % (maxClassNo+1);
+        negClass = posClass;
 
-            bool hasExamples = adaptDataAndLabelForMultiClass(
-                posClass, negClass, XCur, yCur );
+        while ( posClass == negClass )
+        {
+            negClass = rand() % (maxClassNo+1);
+        }
 
-            yCur *= weights;
+        yCur = y;
+        XCur = X;
 
-            // are there examples for positive and negative class?
-            if ( !hasExamples ) continue;
+        hasExamples = adaptDataAndLabelForMultiClass(
+            posClass, negClass, XCur, yCur );
+    }
 
-            // one-vs-all setting: only one iteration for inner loop
-            if ( !useOneVsOne && gotInnerIteration ) continue;
+    yCur *= weights;
 
-            // Preparing system of linear equations
-            NICE::Matrix XTXr, G, temp;
-            regularizeDataMatrix( XCur, XTXr, regularizationType, lambdaCurrent );
-            choleskyDecomp(XTXr, G);
-            choleskyInvert(G, XTXr);
-            temp = XTXr * XCur.transpose();
+    // Preparing system of linear equations
+    NICE::Matrix XTXr, G, temp;
+    regularizeDataMatrix( XCur, XTXr, regularizationType, lambdaCurrent );
+    choleskyDecomp(XTXr, G);
+    choleskyInvert(G, XTXr);
+    temp = XTXr * XCur.transpose();
 
-            // Solve system of linear equations in a least squares manner
-            beta.multiply(temp,yCur,false);
+    // Solve system of linear equations in a least squares manner
+    beta.multiply(temp,yCur,false);
 
-            // Updating parameter vector in convolutional feature
-            f->setParameterVector( beta );
+    // Updating parameter vector in convolutional feature
+    f->setParameterVector( beta );
 
-            // Feature Values
-            values.clear();
-            f->calcFeatureValues( examples, examples_selection, values);
+    // Feature Values
+    values.clear();
+    f->calcFeatureValues( examples, examples_selection, values);
 
-            // complete search for threshold
-            findBestSplitThreshold ( values, bestSplitInfo, beta, e, maxClassNo );
+    // complete search for threshold
+    findBestSplitThreshold ( values, bestSplitInfo, beta, e, maxClassNo );
 
-            gotInnerIteration = true;
-        }
-    }
 //    f->setRandomParameterVector();
 //    beta = f->getParameterVector();
 //    f->calcFeatureValues( examples, examples_selection, values);