123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- /**
- * @file RegKNN.cpp
- * @brief Implementation of k-Nearest-Neighbor algorithm for regression purposes
- * @author Frank Prüfer
- * @date 08/29/2013
- */
- #include <iostream>
- #include "vislearning/regression/npregression/RegKNN.h"
- #include "vislearning/math/mathbase/FullVector.h"
- using namespace OBJREC;
- using namespace std;
- using namespace NICE;
- RegKNN::RegKNN ( const Config *_conf, NICE::VectorDistance<double> *_distancefunc ) : distancefunc (_distancefunc)
- {
- K = _conf->gI("RegKNN", "K", 1 );
- if ( _distancefunc == NULL )
- distancefunc = new EuclidianDistance<double>();
- }
- RegKNN::~RegKNN()
- {
- }
- void RegKNN::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet)
- {
- fprintf (stderr, "teach using all !\n");
- //NOTE this is crucial if we clear _teachSet afterwards!
- //therefore, take care NOT to call _techSet.clear() somewhere out of this method
- this->dataSet = _dataSet;
- this->labelSet = _labelSet;
-
- std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl;
-
- }
- // void RegKNN::teach ( const NICE::Vector & x, const double & y )
- // {
- // std::cerr << "RegKNN::teach one new example" << std::endl;
- //
- // for ( size_t i = 0 ; i < x.size() ; i++ )
- // if ( isnan(x[i]) )
- // {
- // fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
- // cerr << x << endl;
- // exit(-1);
- // }
- //
- // dataSet.push_back ( x );
- // labelSet.push_back ( y );
- //
- // std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
- // }
- double RegKNN::predict ( const NICE::Vector & x )
- {
- FullVector distances(dataSet.size());
- if ( dataSet.size() <= 0 ) {
- fprintf (stderr, "RegKNN: please train this classifier before classifying\n");
- exit(-1);
- }
- for(uint i = 0; i < dataSet.size(); i++){
-
- double distance = distancefunc->calculate (x,dataSet[i]);
-
- if ( isnan(distance) )
- {
- fprintf (stderr, "RegKNN::classify: NAN value found !!\n");
- cerr << x << endl;
- cerr << dataSet[i] << endl;
- }
- distances[i] = distance;
-
- }
- std::vector<int> ind;
- distances.getSortedIndices(ind);
-
- double response = 0.0;
-
- for(uint i = 0; i < K; i++){
- response += labelSet[ind[i]];
- }
-
- return (response / (double) K);
- }
|