genericClassifierSelection.h 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. #ifndef _NICE_OBJREC_GENERICCLASSIFIERSELECTION_INCLUDE
  2. #define _NICE_OBJREC_GENERICCLASSIFIERSELECTION_INCLUDE
  3. //STL
  4. #include <vector>
  5. //core
  6. #include "core/basics/StringTools.h"
  7. //vislearning -- vector classifiers
  8. #include "vislearning/classifier/vclassifier/VCAmitSVM.h"
  9. #include "vislearning/classifier/vclassifier/VCNearestClassMean.h"
  10. #include "vislearning/classifier/vclassifier/VCSimpleGaussian.h"
  11. #include "vislearning/classifier/vclassifier/VCNearestNeighbour.h"
  12. #include "vislearning/classifier/vclassifier/VCCrossGeneralization.h"
  13. #include "vislearning/classifier/classifierinterfaces/VCFeaturePool.h"
  14. #include "vislearning/classifier/vclassifier/VCOneVsOne.h"
  15. #include "vislearning/classifier/vclassifier/VCOneVsAll.h"
  16. #include "vislearning/classifier/vclassifier/VCDTSVM.h"
  17. #include "vislearning/classifier/vclassifier/VCTransform.h"
  18. //vislearning -- kernel classifiers
  19. #include "vislearning/classifier/kernelclassifier/KCGPRegression.h"
  20. #include "vislearning/classifier/kernelclassifier/KCGPLaplace.h"
  21. #include "vislearning/classifier/kernelclassifier/KCGPLaplaceOneVsAll.h"
  22. #include "vislearning/classifier/kernelclassifier/KCOneVsAll.h"
  23. #include "vislearning/classifier/kernelclassifier/KCGPRegOneVsAll.h"
  24. #include "vislearning/classifier/kernelclassifier/KCMinimumEnclosingBall.h"
  25. #include "vislearning/classifier/kernelclassifier/KCGPOneClass.h"
  26. //vislearning -- kernels
  27. #include "vislearning/math/kernels/KernelStd.h"
  28. #include "vislearning/math/kernels/KernelExp.h"
  29. #include "vislearning/math/kernels/KernelRBF.h"
  30. #include "vislearning/math/kernels/genericKernel.h"
  31. //vislearning -- feature pool classifier
  32. #include "vislearning/classifier/fpclassifier/boosting/FPCBoosting.h"
  33. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
  34. #include "vislearning/classifier/fpclassifier/randomforest/FPCDecisionTree.h"
  35. #include "vislearning/classifier/fpclassifier/logisticregression/FPCSMLR.h"
  36. #include "vislearning/classifier/fpclassifier/gphik/FPCGPHIK.h"
  37. //vislearning -- classifier combinations
  38. #include "vislearning/classifier/classifiercombination/VCPreRandomForest.h"
  39. //vislearning -- SVM-based classifiers (vclassifier, kernelclassifier)
  40. #ifdef NICE_USELIB_SVMLIGHT
  41. #include "vislearning/classifier/vclassifier/VCSVMLight.h"
  42. #include "vislearning/classifier/vclassifier/VCSVMOneClass.h"
  43. #include "vislearning/classifier/kernelclassifier/KCSVMLight.h"
  44. #endif
  45. //external stuff
  46. #ifdef NICE_USELIB_NICEDTSVM
  47. #include "nice-dtsvm/VCTreeBasedClassifier.h"
  48. #endif
  49. // #include "gp-hik-exp/GPHIKClassifierNICE.h"
  50. namespace OBJREC {
  51. class GenericClassifierSelection
  52. {
  53. public:
  54. static VecClassifier *selectVecClassifier ( const NICE::Config *conf, std::string classifier_type )
  55. {
  56. std::vector<std::string> submatches;
  57. VecClassifier *classifier = NULL;
  58. if ( classifier_type == "amit" ) {
  59. classifier = new VCAmitSVM ( conf );
  60. } else if ( classifier_type == "nn" ) {
  61. classifier = new VCNearestNeighbour( conf, new NICE::EuclidianDistance<double>() );
  62. #ifdef NICE_USELIB_ICE
  63. } else if ( classifier_type == "gauss" ) {
  64. classifier = new VCSimpleGaussian( conf );
  65. } else if ( classifier_type == "nearest_classmean" ) {
  66. classifier = new VCNearestClassMean( conf, new NICE::EuclidianDistance<double>() );
  67. #endif
  68. }
  69. ////////////////////////////////////////
  70. // //
  71. // all Feature Pool Classifiers //
  72. // //
  73. ////////////////////////////////////////
  74. else if ( classifier_type == "GPHIK" ) {
  75. FeaturePoolClassifier *fpc = new FPCGPHIK ( conf, "GPHIK" );
  76. classifier = new VCFeaturePool ( conf, fpc );
  77. }
  78. else if ( classifier_type == "random_forest" ) {
  79. FeaturePoolClassifier *fpc = new FPCRandomForests ( conf, "RandomForest" );
  80. classifier = new VCFeaturePool ( conf, fpc );
  81. }
  82. else if ( classifier_type == "sparse_logistic_regression" ) {
  83. FeaturePoolClassifier *fpc = new FPCSMLR ( conf, "SparseLogisticRegression" );
  84. classifier = new VCFeaturePool ( conf, fpc );
  85. } else if ( classifier_type == "boost" ) {
  86. FeaturePoolClassifier *fpc = new FPCBoosting ( conf, "Boost" );
  87. classifier = new VCFeaturePool ( conf, fpc );
  88. } else if ( classifier_type == "decision_tree" ) {
  89. FeaturePoolClassifier *fpc = new FPCDecisionTree ( conf, "DecisionTree" );
  90. classifier = new VCFeaturePool ( conf, fpc );
  91. #ifdef NICE_USELIB_ICE
  92. } else if ( ( classifier_type == "cross_generalization" ) || ( classifier_type == "bart" ) ) {
  93. classifier = new VCCrossGeneralization ( conf );
  94. #endif
  95. #ifdef NICE_USELIB_SVMLIGHT
  96. } else if ( ( classifier_type == "svmlight" ) || ( classifier_type == "svm" ) ) {
  97. classifier = new VCSVMLight ( conf );
  98. } else if ( ( classifier_type == "svm_onevsone" ) ) {
  99. classifier = new VCOneVsOne ( conf, new VCSVMLight ( conf ) );
  100. } else if ( ( classifier_type == "svm_onevsall" ) ) {
  101. classifier = new VCOneVsAll ( conf, new VCSVMLight ( conf ) );
  102. } else if ( ( classifier_type == "svmlight_kernel" ) ) {
  103. classifier = new KCSVMLight ( conf, new KernelStd() );
  104. } else if ( ( classifier_type == "svm_one_class" ) ) {
  105. classifier = new VCSVMOneClass ( conf, "VCSVMLight" );
  106. #endif
  107. #ifdef NICE_USELIB_NICEDTSVM
  108. // this classifier requires nice-dtsvm, which is an optional
  109. // nice sub-library
  110. } else if ( classifier_type == "treebased" ) {
  111. classifier = new VCTreeBasedClassifier ( conf );
  112. #endif
  113. } else if ( ( classifier_type == "dtgp" ) ) {
  114. classifier = new VCDTSVM ( conf );
  115. } else if ( ( classifier_type == "minimum_enclosing_ball" ) ) {
  116. std::string kernel_type = conf->gS ( "Kernel", "kernel_function", "rbf" );
  117. classifier = new KCMinimumEnclosingBall ( conf, GenericKernelSelection::selectKernel ( conf, kernel_type ) );
  118. } else if ( ( classifier_type == "gp_one_class" ) ) {
  119. std::string kernel_type = conf->gS ( "Kernel", "kernel_function", "rbf" );
  120. classifier = new KCGPOneClass ( conf, GenericKernelSelection::selectKernel ( conf, kernel_type ) );
  121. } else if ( ( classifier_type == "gp_regression_rbf" ) ) {
  122. std::string kernel_type = conf->gS ( "Kernel", "kernel_function", "rbf" );
  123. classifier = new KCGPRegression ( conf, GenericKernelSelection::selectKernel ( conf, kernel_type ) );
  124. } else if ( ( classifier_type == "gp_laplace_rbf" ) ) {
  125. std::string kernel_type = conf->gS ( "Kernel", "kernel_function", "rbf" );
  126. classifier = new KCGPLaplace ( conf, GenericKernelSelection::selectKernel ( conf, kernel_type ) );
  127. } else if ( ( classifier_type == "gp_regression_rbf_onevsall" ) ) {
  128. std::string kernel_type = conf->gS ( "Kernel", "kernel_function", "rbf" );
  129. classifier = new KCGPRegOneVsAll ( conf, GenericKernelSelection::selectKernel ( conf, kernel_type ) );
  130. } else if ( ( classifier_type == "gp_laplace_rbf_onevsall" ) ) {
  131. std::string kernel_type = conf->gS ( "Kernel", "kernel_function", "rbf" );
  132. classifier = new KCGPLaplaceOneVsAll ( conf, GenericKernelSelection::selectKernel ( conf, kernel_type ) );
  133. } else if ( NICE::StringTools::regexMatch ( classifier_type, "^one_vs_one\\(([^\\)]+)\\)$", submatches ) && ( submatches.size() == 2 ) ) {
  134. classifier = new VCOneVsOne ( conf, selectVecClassifier ( conf, submatches[1] ) );
  135. } else if ( NICE::StringTools::regexMatch ( classifier_type, "^one_vs_all\\(([^\\)]+)\\)$", submatches ) && ( submatches.size() == 2 ) ) {
  136. classifier = new VCOneVsAll ( conf, selectVecClassifier ( conf, submatches[1] ) );
  137. } else if ( NICE::StringTools::regexMatch ( classifier_type, "^random_forest\\(([^\\)]+)\\)$", submatches ) && ( submatches.size() == 2 ) ) {
  138. classifier = new VCPreRandomForest ( conf, "VCPreRandomForest", selectVecClassifier ( conf, submatches[1] ) );
  139. } else {
  140. fthrow ( NICE::Exception, "Classifier type " << classifier_type << " not (yet) supported." << std::endl <<
  141. "(genericClassifierSelection.h contains a list of classifiers to choose from)" );
  142. }
  143. return classifier;
  144. }
  145. };
  146. }
  147. #endif