// KernelClassifier.cpp (2.5 KB)

  1. /**
  2. * @file KernelClassifier.cpp
  3. * @brief classifier interface for kernel based methods
  4. * @author Erik Rodner
  5. * @date 12/02/2009
  6. */
  7. #include <iostream>
  8. #include "KernelClassifier.h"
  9. using namespace std;
  10. using namespace NICE;
  11. using namespace OBJREC;
  12. #undef DEBUG
  13. KernelClassifier::KernelClassifier( const Config *_conf, Kernel *kernelFunction, int normalizationType ) : conf(*_conf)
  14. {
  15. this->kernelFunction = kernelFunction;
  16. this->normalizationType = normalizationType;
  17. }
  18. KernelClassifier::KernelClassifier ( const KernelClassifier & src ) : VecClassifier(), conf(src.conf)
  19. {
  20. if ( src.kernelFunction != NULL )
  21. this->kernelFunction = src.kernelFunction->clone();
  22. else
  23. this->kernelFunction = NULL;
  24. this->vecSetLabels = src.vecSetLabels;
  25. this->normalizationType = src.normalizationType;
  26. this->vecSet = src.vecSet;
  27. }
  28. KernelClassifier::~KernelClassifier()
  29. {
  30. }
  31. ClassificationResult KernelClassifier::classify ( const NICE::Vector & x ) const
  32. {
  33. if ( kernelFunction == NULL )
  34. fthrow( Exception, "KernelClassifier::classify: To use this function, you have to specify a kernel function using the constructor" );
  35. NICE::Vector kstar;
  36. // nothing happens, we should create a copy of x and then
  37. // use normalization
  38. //if ( this->normalizationType == KERNELCLASSIFIER_NORMALIZATION_EUCLIDEAN )
  39. // x.normL2();
  40. kernelFunction->calcKernelVector ( vecSet, x, kstar );
  41. #ifdef DEBUG
  42. if ( kstar.normL2() == 0.0 )
  43. cerr << "The kernel vector k_* has zero norm, you might want to adjust kernel parameters with e.g. Kernel:log_rbf_gamma." << endl;
  44. #endif
  45. double kernelSelf = kernelFunction->K(x,x);
  46. return classifyKernel (kstar, kernelSelf);
  47. }
  48. void KernelClassifier::teach ( const LabeledSetVector & teachSet )
  49. {
  50. if ( kernelFunction == NULL )
  51. fthrow( Exception, "KernelClassifier::teach: To use this function, you have to specify a kernel function using the constructor" );
  52. teachSet.getFlatRepresentation ( vecSet, vecSetLabels );
  53. if ( this->normalizationType == KERNELCLASSIFIER_NORMALIZATION_EUCLIDEAN )
  54. {
  55. for ( VVector::iterator i = vecSet.begin(); i != vecSet.end(); i++ )
  56. {
  57. Vector & v = *i;
  58. v.normL2();
  59. }
  60. }
  61. KernelData *kernelData = new KernelData ( &conf );
  62. kernelFunction->calcKernelData ( vecSet, kernelData );
  63. teach ( kernelData, vecSetLabels );
  64. // maybe problems with other kernel classifiers
  65. delete kernelData;
  66. }
  67. void KernelClassifier::restore(std::istream& ifs, int type)
  68. {
  69. vecSet.restore(ifs);
  70. //TODO: read and write kernel
  71. }
  72. void KernelClassifier::store(std::ostream& ofs, int type) const
  73. {
  74. vecSet.store(ofs);
  75. }