VCSVMLight.h 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. /**
  2. * @file VCSVMLight.h
  3. * @brief Interface to SVMLight from T. Joachims
  4. * @author Erik Rodner
  5. * @date 10/25/2007
  6. */
  7. #ifndef VCSVMLightINCLUDE
  8. #define VCSVMLightINCLUDE
  9. #ifdef NICE_USELIB_SVMLIGHT
  10. #include <vislearning/nice.h>
  11. extern "C" {
  12. #include <svm_common.h>
  13. #include <svm_learn.h>
  14. }
  15. #include "vislearning/cbaselib/LabeledSet.h"
  16. #include "vislearning/classifier/classifierbase/VecClassifier.h"
  17. #include "VCLogisticRegression.h"
  18. namespace OBJREC {
  19. /** Interface to SVMLight from T. Joachims */
  20. class VCSVMLight : public VecClassifier
  21. {
  22. protected:
  23. NICE::Vector max;
  24. NICE::Vector min;
  25. int normalization_type;
  26. enum {
  27. SVM_NORMALIZATION_EUCLIDEAN = 0,
  28. SVM_NORMALIZATION_01,
  29. SVM_NORMALIZATION_NONE
  30. };
  31. int kernel_type;
  32. /** @brief kernel parameters
  33. Use toyExample to see the effects of this parameter
  34. */
  35. double poly_degree;
  36. double rbf_gamma;
  37. double sigmoidpoly_scale;
  38. double sigmoidpoly_bias;
  39. /** regularization parameter of SVM */
  40. double svm_c;
  41. bool use_crossvalidation;
  42. bool optimize_parameters;
  43. /** cross validation settings */
  44. double rbf_gamma_min;
  45. double rbf_gamma_max;
  46. double rbf_gamma_step;
  47. double svm_c_min;
  48. double svm_c_max;
  49. double svm_c_step;
  50. /** SVMLight model */
  51. MODEL *finalModel;
  52. VCLogisticRegression *logreg;
  53. void readDocuments(
  54. DOC ***docs,
  55. double **label,
  56. long int *totwords,
  57. long int *totdoc,
  58. const LabeledSetVector & train );
  59. void svmLightTraining (
  60. const LabeledSetVector & trainSet );
  61. MODEL *singleTraining ( DOC **docs,
  62. double *target, long int totwords, long int totdoc,
  63. MODEL *model, KERNEL_PARM *kernel_parm,
  64. LEARN_PARM *learn_parm );
  65. void initParameters (
  66. LEARN_PARM *learn_parm,
  67. KERNEL_PARM *kernel_parm );
  68. void initMainParameters ( LEARN_PARM *learn_parm,
  69. KERNEL_PARM *kernel_parm );
  70. void estimateMaxMin ( const LabeledSetVector & train );
  71. void normalizeVector ( int normalization_type, NICE::Vector & x ) const;
  72. double getSVMScore ( const NICE::Vector & x ) const;
  73. MODEL *optimizeParameters ( DOC **docs,
  74. double *target, long int totwords, long int totdoc,
  75. MODEL *model, KERNEL_PARM *kernel_parm,
  76. LEARN_PARM *learn_parm );
  77. public:
  78. /** using a config file to read some settings */
  79. VCSVMLight( const NICE::Config *conf, const std::string & section = "VCSVMLight" );
  80. /** simple copy constructor */
  81. VCSVMLight ( const VCSVMLight & src );
  82. /** simple destructor */
  83. virtual ~VCSVMLight();
  84. /** classify using simple vector */
  85. ClassificationResult classify ( const NICE::Vector & x ) const;
  86. /** classify using a simple vector */
  87. void teach ( const LabeledSetVector & teachSet );
  88. void finishTeaching();
  89. /** clone this object */
  90. virtual VCSVMLight *clone(void) const;
  91. void clear ();
  92. void read (const std::string& s, int format = 0);
  93. void save (const std::string& s, int format = 0) const;
  94. void store ( std::ostream & os, int format = 0 ) const;
  95. void restore ( std::istream & is, int format = 0 );
  96. };
  97. } // namespace
  98. #endif
  99. #endif