GPHIKRawClassifierMex.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. #ifdef NICE_USELIB_MEX
  2. /**
  3. * @file GPHIKRawClassifierMex.cpp
  4. * @author Alexander Freytag
  5. * @date 21-09-2015 (dd-mm-yyyy)
  6. * @brief Matlab-Interface of our GPHIKRawClassifier, allowing for training and classification without more advanced methods.
  7. */
  8. // STL includes
  9. #include <math.h>
  10. #include <matrix.h>
  11. #include <mex.h>
  12. // NICE-core includes
  13. #include <core/basics/Config.h>
  14. #include <core/basics/Timer.h>
  15. #include <core/vector/MatrixT.h>
  16. #include <core/vector/VectorT.h>
  17. // gp-hik-core includes
  18. #include "gp-hik-core/GPHIKRawClassifier.h"
  19. // Interface for conversion between Matlab and C objects
  20. #include "gp-hik-core/matlab/classHandleMtoC.h"
  21. #include "gp-hik-core/matlab/ConverterMatlabToNICE.h"
  22. #include "gp-hik-core/matlab/ConverterNICEToMatlab.h"
  23. using namespace std; //C basics
  24. using namespace NICE; // nice-core
  25. NICE::Config parseParametersGPHIKRawClassifier(const mxArray *prhs[], int nrhs)
  26. {
  27. NICE::Config conf;
  28. // if first argument is the filename of an existing config file,
  29. // read the config accordingly
  30. int i_start ( 0 );
  31. std::string variable = MatlabConversion::convertMatlabToString(prhs[i_start]);
  32. if(variable == "conf")
  33. {
  34. conf = NICE::Config ( MatlabConversion::convertMatlabToString( prhs[i_start+1] ) );
  35. i_start = i_start+2;
  36. }
  37. // now run over all given parameter specifications
  38. // and add them to the config
  39. for( int i=i_start; i < nrhs; i+=2 )
  40. {
  41. std::string variable = MatlabConversion::convertMatlabToString(prhs[i]);
  42. /////////////////////////////////////////
  43. // READ STANDARD BOOLEAN VARIABLES
  44. /////////////////////////////////////////
  45. if( (variable == "verbose") ||
  46. (variable == "debug") ||
  47. (variable == "ils_verbose")
  48. )
  49. {
  50. if ( mxIsChar( prhs[i+1] ) )
  51. {
  52. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  53. if ( (value != "true") && (value != "false") )
  54. {
  55. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. In string modus, \'true\' or \'false\' expected.";
  56. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  57. }
  58. if( value == "true" )
  59. conf.sB("GPHIKRawClassifier", variable, true);
  60. else
  61. conf.sB("GPHIKRawClassifier", variable, false);
  62. }
  63. else if ( mxIsLogical( prhs[i+1] ) )
  64. {
  65. bool value = MatlabConversion::convertMatlabToBool( prhs[i+1] );
  66. conf.sB("GPHIKRawClassifier", variable, value);
  67. }
  68. else
  69. {
  70. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. \'true\', \'false\', or logical expected.";
  71. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  72. }
  73. }
  74. /////////////////////////////////////////
  75. // READ STANDARD INT VARIABLES
  76. /////////////////////////////////////////
  77. /////////////////////////////////////////
  78. // READ STRICT POSITIVE INT VARIABLES
  79. /////////////////////////////////////////
  80. if ( ( variable == "ils_max_iterations" )||
  81. ( variable == "eig_value_max_iterations" )
  82. )
  83. {
  84. if ( mxIsDouble( prhs[i+1] ) )
  85. {
  86. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  87. if( value < 1 )
  88. {
  89. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  90. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  91. }
  92. conf.sI("GPHIKRawClassifier", variable, (int) value);
  93. }
  94. else if ( mxIsInt32( prhs[i+1] ) )
  95. {
  96. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  97. if( value < 1 )
  98. {
  99. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  100. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  101. }
  102. conf.sI("GPHIKRawClassifier", variable, value);
  103. }
  104. else
  105. {
  106. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  107. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  108. }
  109. }
  110. /////////////////////////////////////////
  111. // READ STANDARD DOUBLE VARIABLES
  112. /////////////////////////////////////////
  113. /////////////////////////////////////////
  114. // READ POSITIVE DOUBLE VARIABLES
  115. /////////////////////////////////////////
  116. if ( (variable == "f_tolerance") ||
  117. (variable == "ils_min_delta") ||
  118. (variable == "ils_min_residual") ||
  119. (variable == "noise")
  120. )
  121. {
  122. if ( mxIsDouble( prhs[i+1] ) )
  123. {
  124. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  125. if( value < 0.0 )
  126. {
  127. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  128. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  129. }
  130. conf.sD("GPHIKRawClassifier", variable, value);
  131. }
  132. else
  133. {
  134. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  135. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  136. }
  137. }
  138. /////////////////////////////////////////
  139. // READ REMAINING SPECIFIC VARIABLES
  140. /////////////////////////////////////////
  141. if(variable == "ils_method")
  142. {
  143. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  144. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  145. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  146. conf.sS("GPHIKRawClassifier", variable, value);
  147. }
  148. }
  149. return conf;
  150. }
  151. // MAIN MATLAB FUNCTION
  152. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  153. {
  154. // get the command string specifying what to do
  155. if (nrhs < 1)
  156. mexErrMsgTxt("No commands and options passed... Aborting!");
  157. if( !mxIsChar( prhs[0] ) )
  158. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  159. std::string cmd = MatlabConversion::convertMatlabToString( prhs[0] );
  160. // create object
  161. if ( !strcmp("new", cmd.c_str() ) )
  162. {
  163. // check output variable
  164. if (nlhs != 1)
  165. mexErrMsgTxt("New: One output expected.");
  166. // read config settings
  167. NICE::Config conf = parseParametersGPHIKRawClassifier(prhs+1,nrhs-1);
  168. // create class instance
  169. NICE::GPHIKRawClassifier * classifier = new NICE::GPHIKRawClassifier ( &conf, "GPHIKRawClassifier" /*sectionName in config*/ );
  170. // handle to the C++ instance
  171. plhs[0] = MatlabConversion::convertPtr2Mat<NICE::GPHIKRawClassifier>( classifier );
  172. return;
  173. }
  174. // in all other cases, there should be a second input,
  175. // which the be the class instance handle
  176. if (nrhs < 2)
  177. mexErrMsgTxt("Second input should be a class instance handle.");
  178. // delete object
  179. if ( !strcmp("delete", cmd.c_str() ) )
  180. {
  181. // Destroy the C++ object
  182. MatlabConversion::destroyObject<NICE::GPHIKRawClassifier>(prhs[1]);
  183. return;
  184. }
  185. // get the class instance pointer from the second input
  186. // every following function needs the classifier object
  187. NICE::GPHIKRawClassifier * classifier = MatlabConversion::convertMat2Ptr<NICE::GPHIKRawClassifier>(prhs[1]);
  188. ////////////////////////////////////////
  189. // Check which class method to call //
  190. ////////////////////////////////////////
  191. // standard train - assumes initialized object
  192. if (!strcmp("train", cmd.c_str() ))
  193. {
  194. // Check parameters
  195. if (nlhs < 0 || nrhs < 4)
  196. {
  197. mexErrMsgTxt("Train: Unexpected arguments.");
  198. }
  199. //------------- read the data --------------
  200. std::vector< const NICE::SparseVector *> examplesTrain;
  201. NICE::Vector yMultiTrain;
  202. if ( mxIsSparse( prhs[2] ) )
  203. {
  204. examplesTrain = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  205. }
  206. else
  207. {
  208. NICE::Matrix dataTrain;
  209. dataTrain = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  210. //----------------- convert data to sparse data structures ---------
  211. examplesTrain.resize( dataTrain.rows() );
  212. std::vector< const NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  213. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  214. {
  215. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  216. }
  217. }
  218. yMultiTrain = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  219. //----------------- train our classifier -------------
  220. classifier->train ( examplesTrain , yMultiTrain );
  221. //----------------- clean up -------------
  222. for(int i=0;i<examplesTrain.size();i++)
  223. delete examplesTrain[i];
  224. return;
  225. }
  226. // Classify
  227. if ( !strcmp("classify", cmd.c_str() ) )
  228. {
  229. // Check parameters
  230. if ( (nlhs < 0) || (nrhs < 2) )
  231. {
  232. mexErrMsgTxt("Test: Unexpected arguments.");
  233. }
  234. //------------- read the data --------------
  235. uint result;
  236. NICE::SparseVector scores;
  237. if ( mxIsSparse( prhs[2] ) )
  238. {
  239. NICE::SparseVector * example;
  240. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  241. classifier->classify ( example, result, scores );
  242. //----------------- clean up -------------
  243. delete example;
  244. }
  245. else
  246. {
  247. NICE::Vector * example;
  248. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  249. NICE::SparseVector * svec = new NICE::SparseVector( *example );
  250. delete example;
  251. classifier->classify ( svec, result, scores );
  252. //----------------- clean up -------------
  253. delete svec;
  254. }
  255. // output
  256. plhs[0] = mxCreateDoubleScalar( result );
  257. if(nlhs >= 2)
  258. {
  259. plhs[1] = MatlabConversion::convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  260. }
  261. return;
  262. }
  263. // Test - evaluate classifier on whole test set
  264. if ( !strcmp("test", cmd.c_str() ) )
  265. {
  266. // Check parameters
  267. if (nlhs < 0 || nrhs < 4)
  268. mexErrMsgTxt("Test: Unexpected arguments.");
  269. //------------- read the data --------------
  270. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  271. std::vector< const NICE::SparseVector *> dataTest_sparse;
  272. NICE::Matrix dataTest_dense;
  273. if ( dataIsSparse )
  274. {
  275. dataTest_sparse = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  276. }
  277. else
  278. {
  279. dataTest_dense = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  280. }
  281. NICE::Vector yMultiTest;
  282. yMultiTest = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  283. // ------------------------------------------
  284. // ------------- PREPARATION --------------
  285. // ------------------------------------------
  286. // determine classes known during training and corresponding mapping
  287. // thereby allow for non-continous class labels
  288. std::set< uint > classesKnownTraining = classifier->getKnownClassNumbers();
  289. uint noClassesKnownTraining ( classesKnownTraining.size() );
  290. std::map< uint, uint > mapClNoToIdxTrain;
  291. std::set< uint >::const_iterator clTrIt = classesKnownTraining.begin();
  292. for ( uint i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  293. mapClNoToIdxTrain.insert ( std::pair< uint, uint > ( *clTrIt, i ) );
  294. // determine classes known during testing and corresponding mapping
  295. // thereby allow for non-continous class labels
  296. std::set< uint > classesKnownTest;
  297. classesKnownTest.clear();
  298. // determine which classes we have in our label vector
  299. // -> MATLAB: myClasses = unique(y);
  300. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  301. {
  302. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  303. {
  304. classesKnownTest.insert ( *it );
  305. }
  306. }
  307. int noClassesKnownTest ( classesKnownTest.size() );
  308. std::map< uint, uint> mapClNoToIdxTest;
  309. std::set< uint >::const_iterator clTestIt = classesKnownTest.begin();
  310. for ( uint i=0; i < noClassesKnownTest; i++, clTestIt++ )
  311. mapClNoToIdxTest.insert ( std::pair< uint, uint > ( *clTestIt, i ) );
  312. int i_numTestSamples;
  313. if ( dataIsSparse )
  314. i_numTestSamples = dataTest_sparse.size();
  315. else
  316. i_numTestSamples = (int) dataTest_dense.rows();
  317. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  318. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  319. // ------------------------------------------
  320. // ------------- CLASSIFICATION --------------
  321. // ------------------------------------------
  322. NICE::Timer t;
  323. double testTime (0.0);
  324. for (int i = 0; i < i_numTestSamples; i++)
  325. {
  326. //----------------- convert data to sparse data structures ---------
  327. uint result;
  328. NICE::SparseVector exampleScoresSparse;
  329. if ( dataIsSparse )
  330. {
  331. // and classify
  332. t.start();
  333. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  334. t.stop();
  335. testTime += t.getLast();
  336. }
  337. else
  338. {
  339. NICE::Vector example ( dataTest_dense.getRow(i) );
  340. NICE::SparseVector * svec = new NICE::SparseVector ( example );
  341. // and classify
  342. t.start();
  343. classifier->classify( svec, result, exampleScoresSparse );
  344. t.stop();
  345. testTime += t.getLast();
  346. delete svec;
  347. }
  348. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  349. int scoreCnt ( 0 );
  350. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  351. {
  352. scores(i,scoreCnt) = scoreIt->second;
  353. }
  354. }
  355. std::cerr << "Time for testing: " << testTime << std::endl;
  356. // clean up
  357. if ( dataIsSparse )
  358. {
  359. for ( std::vector<const NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  360. delete *it;
  361. }
  362. confusionMatrix.normalizeColumnsL1();
  363. double recRate = confusionMatrix.trace()/confusionMatrix.cols();
  364. plhs[0] = mxCreateDoubleScalar( recRate );
  365. if(nlhs >= 2)
  366. plhs[1] = MatlabConversion::convertMatrixFromNice(confusionMatrix);
  367. if(nlhs >= 3)
  368. plhs[2] = MatlabConversion::convertMatrixFromNice(scores);
  369. return;
  370. }
  371. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  372. // interface specific methods for incremental extensions
  373. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  374. // not supported here
  375. ///////////////////// INTERFACE PERSISTENT /////////////////////
  376. // interface specific methods for store and restore
  377. ///////////////////// INTERFACE PERSISTENT /////////////////////
  378. // not supported here
  379. // Got here, so command not recognized
  380. std::string errorMsg (cmd.c_str() );
  381. errorMsg += " -- command not recognized.";
  382. mexErrMsgTxt( errorMsg.c_str() );
  383. }
  384. #endif