Browse Source

k-Nearest-Neighbor regression now uses weighted average for prediction.

Frank Prüfer 11 years ago
parent
commit
232f7d3b6b
2 changed files with 45 additions and 29 deletions
  1. 42 26
      regression/npregression/RegKNN.cpp
  2. 3 3
      regression/npregression/RegKNN.h

+ 42 - 26
regression/npregression/RegKNN.cpp

@@ -34,36 +34,37 @@ void RegKNN::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _label
     //NOTE this is crucial if we clear _teachSet afterwards!
     //therefore, take care NOT to call _techSet.clear() somewhere out of this method
     this->dataSet = _dataSet;
-    this->labelSet = _labelSet;
+    this->labelSet = _labelSet.std_vector();
     
     std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl;   
     
 }
 
-// void RegKNN::teach ( const NICE::Vector & x, const double & y )
-// {
-//     std::cerr << "RegKNN::teach one new example" << std::endl;
-//     
-//     for ( size_t i = 0 ; i < x.size() ; i++ )
-//       if ( isnan(x[i]) ) 
-//       {
-//           fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
-//           cerr << x << endl;
-//           exit(-1);
-//       }
-// 
-//     dataSet.push_back ( x );
-//     labelSet.push_back ( y );
-//     
-//     std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
-// }
+void RegKNN::teach ( const NICE::Vector & x, const double & y )  // Adds one training example: feature vector x with response y.
+{
+    std::cerr << "RegKNN::teach one new example" << std::endl;
+    // Abort if any component of x is NaN: report its index and value, dump x, then exit(-1).
+    for ( size_t i = 0 ; i < x.size() ; i++ )
+      if ( isnan(x[i]) ) 
+      {
+          fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
+          cerr << x << endl;
+          exit(-1);
+      }
+
+    dataSet.push_back ( x );
+    // Labels are appended in step with dataSet; predict() looks up labelSet by the index of the matching dataSet entry.
+    labelSet.push_back ( y );
+    // Log the number of samples stored so far.
+    std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
+}
 
 double RegKNN::predict ( const NICE::Vector & x )
 {
     FullVector distances(dataSet.size());
 
     if ( dataSet.size() <= 0 ) {
-		fprintf (stderr, "RegKNN: please train this classifier before classifying\n");
+		fprintf (stderr, "RegKNN: please use the train method first\n");
 		exit(-1);
     }
 
@@ -71,23 +72,38 @@ double RegKNN::predict ( const NICE::Vector & x )
     
       double distance = distancefunc->calculate (x,dataSet[i]);
       
-      if ( isnan(distance) )
-      {
-          fprintf (stderr, "RegKNN::classify: NAN value found !!\n");
+      if ( isnan(distance) ){
+          fprintf (stderr, "RegKNN::predict: NAN value found !!\n");
           cerr << x << endl;
           cerr << dataSet[i] << endl;
       }
-      distances[i] = distance;
-      
+      distances[i] = distance;     
     }
+    
     std::vector<int> ind;
     distances.getSortedIndices(ind);
     
     double response = 0.0;  
     
+    if ( dataSet.size() < K ){
+      K = dataSet.size();
+      cerr<<"RegKNN: Not enough datapoints! Setting K to: "<< K <<endl;
+    }
+        
+    if ( distances[ind[0]] == 0.0 ) {
+      cerr<<"RegKNN: Warning: datapoint was already seen during training... using its label as prediction."<<endl;
+      return labelSet[ind[0]];  
+    }
+    
+    double maxElement = distances.max();	//normalize distances
+    distances.multiply(1.0/maxElement);
+    
+    double weightSum = 0.0;
+    
     for(uint i = 0; i < K; i++){
-      response += labelSet[ind[i]];
+      response += 1.0/distances[ind[i]] * labelSet[ind[i]];
+      weightSum += 1.0/distances[ind[i]];
     }
        
-    return (response / (double) K);
+    return ( response / weightSum );
 }

+ 3 - 3
regression/npregression/RegKNN.h

@@ -29,7 +29,7 @@ class RegKNN : public RegressionAlgorithm
     NICE::VVector dataSet;
     
     /** set of responses according to dataset */
-    NICE::Vector labelSet;
+    std::vector<double> labelSet;
     
     /** used distance function */
     NICE::VectorDistance<double> *distancefunc;
@@ -47,8 +47,8 @@ class RegKNN : public RegressionAlgorithm
     /** teach whole set at once */
     void teach ( const NICE::VVector & dataSet, const NICE::Vector & labelSet );
 
-//     /** teach one data point at a time */
-//     void teach ( const NICE::Vector & x, const double & y );
+    /** teach one data point at a time */
+    void teach ( const NICE::Vector & x, const double & y );
 };
 }	//namespace