GPHIKClassifierMex.cpp 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752
  1. #ifdef NICE_USELIB_MEX
  2. /**
  3. * @file GPHIKClassifierMex.cpp
  4. * @author Alexander Freytag
  5. * @date 07-01-2014 (dd-mm-yyyy)
  6. * @brief Matlab-Interface of our GPHIKClassifier, allowing for training, classification, optimization, variance prediction, incremental learning, and storing/re-storing.
  7. */
  8. // STL includes
  9. #include <math.h>
  10. #include <matrix.h>
  11. #include <mex.h>
  12. // NICE-core includes
  13. #include <core/basics/Config.h>
  14. #include <core/basics/Timer.h>
  15. #include <core/vector/MatrixT.h>
  16. #include <core/vector/VectorT.h>
  17. // gp-hik-core includes
  18. #include "gp-hik-core/GPHIKClassifier.h"
  19. // Interface for conversion between Matlab and C objects
  20. #include "gp-hik-core/matlab/classHandleMtoC.h"
  21. #include "gp-hik-core/matlab/ConverterMatlabToNICE.h"
  22. #include "gp-hik-core/matlab/ConverterNICEToMatlab.h"
  23. using namespace std; //C basics
  24. using namespace NICE; // nice-core
  25. NICE::Config parseParametersGPHIKClassifier(const mxArray *prhs[], int nrhs)
  26. {
  27. NICE::Config conf;
  28. // if first argument is the filename of an existing config file,
  29. // read the config accordingly
  30. int i_start ( 0 );
  31. std::string variable = MatlabConversion::convertMatlabToString(prhs[i_start]);
  32. if(variable == "conf")
  33. {
  34. conf = NICE::Config ( MatlabConversion::convertMatlabToString( prhs[i_start+1] ) );
  35. i_start = i_start+2;
  36. }
  37. // now run over all given parameter specifications
  38. // and add them to the config
  39. for( int i=i_start; i < nrhs; i+=2 )
  40. {
  41. std::string variable = MatlabConversion::convertMatlabToString(prhs[i]);
  42. /////////////////////////////////////////
  43. // READ STANDARD BOOLEAN VARIABLES
  44. /////////////////////////////////////////
  45. if( (variable == "verboseTime") ||
  46. (variable == "verbose") ||
  47. (variable == "optimize_noise") ||
  48. (variable == "uncertaintyPredictionForClassification") ||
  49. (variable == "use_quantization") ||
  50. (variable == "ils_verbose")
  51. )
  52. {
  53. if ( mxIsChar( prhs[i+1] ) )
  54. {
  55. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  56. if ( (value != "true") && (value != "false") )
  57. {
  58. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. In string modus, \'true\' or \'false\' expected.";
  59. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  60. }
  61. if( value == "true" )
  62. conf.sB("GPHIKClassifier", variable, true);
  63. else
  64. conf.sB("GPHIKClassifier", variable, false);
  65. }
  66. else if ( mxIsLogical( prhs[i+1] ) )
  67. {
  68. bool value = MatlabConversion::convertMatlabToBool( prhs[i+1] );
  69. conf.sB("GPHIKClassifier", variable, value);
  70. }
  71. else
  72. {
  73. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. \'true\', \'false\', or logical expected.";
  74. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  75. }
  76. }
  77. /////////////////////////////////////////
  78. // READ STANDARD INT VARIABLES
  79. /////////////////////////////////////////
  80. if ( (variable == "nrOfEigenvaluesToConsiderForVarApprox")
  81. )
  82. {
  83. if ( mxIsDouble( prhs[i+1] ) )
  84. {
  85. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  86. conf.sI("GPHIKClassifier", variable, (int) value);
  87. }
  88. else if ( mxIsInt32( prhs[i+1] ) )
  89. {
  90. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  91. conf.sI("GPHIKClassifier", variable, value);
  92. }
  93. else
  94. {
  95. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  96. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  97. }
  98. }
  99. /////////////////////////////////////////
  100. // READ STRICT POSITIVE INT VARIABLES
  101. /////////////////////////////////////////
  102. if ( (variable == "num_bins") ||
  103. (variable == "ils_max_iterations")
  104. )
  105. {
  106. if ( mxIsDouble( prhs[i+1] ) )
  107. {
  108. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  109. if( value < 1 )
  110. {
  111. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  112. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  113. }
  114. conf.sI("GPHIKClassifier", variable, (int) value);
  115. }
  116. else if ( mxIsInt32( prhs[i+1] ) )
  117. {
  118. int value = MatlabConversion::convertMatlabToInt32(prhs[i+1]);
  119. if( value < 1 )
  120. {
  121. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  122. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  123. }
  124. conf.sI("GPHIKClassifier", variable, value);
  125. }
  126. else
  127. {
  128. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  129. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  130. }
  131. }
  132. /////////////////////////////////////////
  133. // READ STANDARD DOUBLE VARIABLES
  134. /////////////////////////////////////////
  135. if ( (variable == "parameter_upper_bound") ||
  136. (variable == "parameter_lower_bound")
  137. )
  138. {
  139. if ( mxIsDouble( prhs[i+1] ) )
  140. {
  141. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  142. conf.sD("GPHIKClassifier", variable, value);
  143. }
  144. else
  145. {
  146. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  147. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  148. }
  149. }
  150. /////////////////////////////////////////
  151. // READ POSITIVE DOUBLE VARIABLES
  152. /////////////////////////////////////////
  153. if ( (variable == "ils_min_delta") ||
  154. (variable == "ils_min_residual") ||
  155. (variable == "noise")
  156. )
  157. {
  158. if ( mxIsDouble( prhs[i+1] ) )
  159. {
  160. double value = MatlabConversion::convertMatlabToDouble(prhs[i+1]);
  161. if( value < 0.0 )
  162. {
  163. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  164. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  165. }
  166. conf.sD("GPHIKClassifier", variable, value);
  167. }
  168. else
  169. {
  170. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  171. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  172. }
  173. }
  174. /////////////////////////////////////////
  175. // READ REMAINING SPECIFIC VARIABLES
  176. /////////////////////////////////////////
  177. if(variable == "ils_method")
  178. {
  179. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  180. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  181. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  182. conf.sS("GPHIKClassifier", variable, value);
  183. }
  184. if(variable == "optimization_method")
  185. {
  186. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  187. if(value != "greedy" && value != "downhillsimplex" && value != "none")
  188. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimization_method\'. \'greedy\', \'downhillsimplex\' or \'none\' expected.");
  189. conf.sS("GPHIKClassifier", variable, value);
  190. }
  191. if(variable == "transform")
  192. {
  193. string value = MatlabConversion::convertMatlabToString( prhs[i+1] );
  194. if( value != "identity" && value != "absexp" && value != "exp" && value != "MKL" && value != "WeightedDim")
  195. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'transform\'. \'identity\', \'absexp\', \'exp\' , \'MKL\' or \'WeightedDim\' expected.");
  196. conf.sS("GPHIKClassifier", variable, value);
  197. }
  198. if(variable == "varianceApproximation")
  199. {
  200. string value = MatlabConversion::convertMatlabToString(prhs[i+1]);
  201. if(value != "approximate_fine" && value != "approximate_rough" && value != "exact" && value != "none")
  202. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'varianceApproximation\'. \'approximate_fine\', \'approximate_rough\', \'none\' or \'exact\' expected.");
  203. conf.sS("GPHIKClassifier", variable, value);
  204. }
  205. }
  206. return conf;
  207. }
  208. // MAIN MATLAB FUNCTION
  209. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  210. {
  211. // get the command string specifying what to do
  212. if (nrhs < 1)
  213. mexErrMsgTxt("No commands and options passed... Aborting!");
  214. if( !mxIsChar( prhs[0] ) )
  215. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  216. std::string cmd = MatlabConversion::convertMatlabToString( prhs[0] );
  217. // create object
  218. if ( !strcmp("new", cmd.c_str() ) )
  219. {
  220. // check output variable
  221. if (nlhs != 1)
  222. mexErrMsgTxt("New: One output expected.");
  223. // read config settings
  224. NICE::Config conf = parseParametersGPHIKClassifier(prhs+1,nrhs-1);
  225. // create class instance
  226. NICE::GPHIKClassifier * classifier = new NICE::GPHIKClassifier ( &conf, "GPHIKClassifier" /*sectionName in config*/ );
  227. // handle to the C++ instance
  228. plhs[0] = MatlabConversion::convertPtr2Mat<NICE::GPHIKClassifier>( classifier );
  229. return;
  230. }
  231. // in all other cases, there should be a second input,
  232. // which the be the class instance handle
  233. if (nrhs < 2)
  234. mexErrMsgTxt("Second input should be a class instance handle.");
  235. // delete object
  236. if ( !strcmp("delete", cmd.c_str() ) )
  237. {
  238. // Destroy the C++ object
  239. MatlabConversion::destroyObject<NICE::GPHIKClassifier>(prhs[1]);
  240. return;
  241. }
  242. // get the class instance pointer from the second input
  243. // every following function needs the classifier object
  244. NICE::GPHIKClassifier * classifier = MatlabConversion::convertMat2Ptr<NICE::GPHIKClassifier>(prhs[1]);
  245. ////////////////////////////////////////
  246. // Check which class method to call //
  247. ////////////////////////////////////////
  248. // standard train - assumes initialized object
  249. if (!strcmp("train", cmd.c_str() ))
  250. {
  251. // Check parameters
  252. if (nlhs < 0 || nrhs < 4)
  253. {
  254. mexErrMsgTxt("Train: Unexpected arguments.");
  255. }
  256. //------------- read the data --------------
  257. std::vector< const NICE::SparseVector *> examplesTrain;
  258. NICE::Vector yMultiTrain;
  259. if ( mxIsSparse( prhs[2] ) )
  260. {
  261. examplesTrain = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  262. }
  263. else
  264. {
  265. NICE::Matrix dataTrain;
  266. dataTrain = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  267. //----------------- convert data to sparse data structures ---------
  268. examplesTrain.resize( dataTrain.rows() );
  269. std::vector< const NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  270. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  271. {
  272. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  273. }
  274. }
  275. yMultiTrain = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  276. //----------------- train our classifier -------------
  277. classifier->train ( examplesTrain , yMultiTrain );
  278. //----------------- clean up -------------
  279. for(int i=0;i<examplesTrain.size();i++)
  280. delete examplesTrain[i];
  281. return;
  282. }
  283. // Classify
  284. if ( !strcmp("classify", cmd.c_str() ) )
  285. {
  286. // Check parameters
  287. if ( (nlhs < 0) || (nrhs < 2) )
  288. {
  289. mexErrMsgTxt("Test: Unexpected arguments.");
  290. }
  291. //------------- read the data --------------
  292. int result;
  293. NICE::SparseVector scores;
  294. double uncertainty;
  295. if ( mxIsSparse( prhs[2] ) )
  296. {
  297. NICE::SparseVector * example;
  298. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  299. classifier->classify ( example, result, scores, uncertainty );
  300. //----------------- clean up -------------
  301. delete example;
  302. }
  303. else
  304. {
  305. NICE::Vector * example;
  306. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  307. classifier->classify ( example, result, scores, uncertainty );
  308. //----------------- clean up -------------
  309. delete example;
  310. }
  311. // output
  312. plhs[0] = mxCreateDoubleScalar( result );
  313. if(nlhs >= 2)
  314. {
  315. plhs[1] = MatlabConversion::convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  316. }
  317. if(nlhs >= 3)
  318. {
  319. plhs[2] = mxCreateDoubleScalar( uncertainty );
  320. }
  321. return;
  322. }
  323. // Uncertainty prediction
  324. if ( !strcmp("uncertainty", cmd.c_str() ) )
  325. {
  326. // Check parameters
  327. if ( (nlhs < 0) || (nrhs < 2) )
  328. {
  329. mexErrMsgTxt("Test: Unexpected arguments.");
  330. }
  331. double uncertainty;
  332. //------------- read the data --------------
  333. if ( mxIsSparse( prhs[2] ) )
  334. {
  335. NICE::SparseVector * example;
  336. example = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  337. classifier->predictUncertainty( example, uncertainty );
  338. //----------------- clean up -------------
  339. delete example;
  340. }
  341. else
  342. {
  343. NICE::Vector * example;
  344. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  345. classifier->predictUncertainty( example, uncertainty );
  346. //----------------- clean up -------------
  347. delete example;
  348. }
  349. // output
  350. plhs[0] = mxCreateDoubleScalar( uncertainty );
  351. return;
  352. }
  353. // Test - evaluate classifier on whole test set
  354. if ( !strcmp("test", cmd.c_str() ) )
  355. {
  356. // Check parameters
  357. if (nlhs < 0 || nrhs < 4)
  358. mexErrMsgTxt("Test: Unexpected arguments.");
  359. //------------- read the data --------------
  360. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  361. std::vector< const NICE::SparseVector *> dataTest_sparse;
  362. NICE::Matrix dataTest_dense;
  363. if ( dataIsSparse )
  364. {
  365. dataTest_sparse = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  366. }
  367. else
  368. {
  369. dataTest_dense = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  370. }
  371. NICE::Vector yMultiTest;
  372. yMultiTest = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  373. // ------------------------------------------
  374. // ------------- PREPARATION --------------
  375. // ------------------------------------------
  376. // determine classes known during training and corresponding mapping
  377. // thereby allow for non-continous class labels
  378. std::set<int> classesKnownTraining = classifier->getKnownClassNumbers();
  379. int noClassesKnownTraining ( classesKnownTraining.size() );
  380. std::map<int,int> mapClNoToIdxTrain;
  381. std::set<int>::const_iterator clTrIt = classesKnownTraining.begin();
  382. for ( int i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  383. mapClNoToIdxTrain.insert ( std::pair<int,int> ( *clTrIt, i ) );
  384. // determine classes known during testing and corresponding mapping
  385. // thereby allow for non-continous class labels
  386. std::set<int> classesKnownTest;
  387. classesKnownTest.clear();
  388. // determine which classes we have in our label vector
  389. // -> MATLAB: myClasses = unique(y);
  390. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  391. {
  392. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  393. {
  394. classesKnownTest.insert ( *it );
  395. }
  396. }
  397. int noClassesKnownTest ( classesKnownTest.size() );
  398. std::map<int,int> mapClNoToIdxTest;
  399. std::set<int>::const_iterator clTestIt = classesKnownTest.begin();
  400. for ( int i=0; i < noClassesKnownTest; i++, clTestIt++ )
  401. mapClNoToIdxTest.insert ( std::pair<int,int> ( *clTestIt, i ) );
  402. int i_numTestSamples;
  403. if ( dataIsSparse )
  404. i_numTestSamples = dataTest_sparse.size();
  405. else
  406. i_numTestSamples = (int) dataTest_dense.rows();
  407. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  408. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  409. // ------------------------------------------
  410. // ------------- CLASSIFICATION --------------
  411. // ------------------------------------------
  412. NICE::Timer t;
  413. double testTime (0.0);
  414. for (int i = 0; i < i_numTestSamples; i++)
  415. {
  416. //----------------- convert data to sparse data structures ---------
  417. int result;
  418. NICE::SparseVector exampleScoresSparse;
  419. if ( dataIsSparse )
  420. {
  421. // and classify
  422. t.start();
  423. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  424. t.stop();
  425. testTime += t.getLast();
  426. }
  427. else
  428. {
  429. NICE::Vector example ( dataTest_dense.getRow(i) );
  430. // and classify
  431. t.start();
  432. classifier->classify( &example, result, exampleScoresSparse );
  433. t.stop();
  434. testTime += t.getLast();
  435. }
  436. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  437. int scoreCnt ( 0 );
  438. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  439. {
  440. scores(i,scoreCnt) = scoreIt->second;
  441. }
  442. }
  443. std::cerr << "Time for testing: " << testTime << std::endl;
  444. // clean up
  445. if ( dataIsSparse )
  446. {
  447. for ( std::vector<const NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  448. delete *it;
  449. }
  450. confusionMatrix.normalizeColumnsL1();
  451. double recRate = confusionMatrix.trace()/confusionMatrix.cols();
  452. plhs[0] = mxCreateDoubleScalar( recRate );
  453. if(nlhs >= 2)
  454. plhs[1] = MatlabConversion::convertMatrixFromNice(confusionMatrix);
  455. if(nlhs >= 3)
  456. plhs[2] = MatlabConversion::convertMatrixFromNice(scores);
  457. return;
  458. }
  459. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  460. // interface specific methods for incremental extensions
  461. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  462. // addExample
  463. if ( !strcmp("addExample", cmd.c_str() ) )
  464. {
  465. // Check parameters
  466. if ( (nlhs < 0) || (nrhs < 4) )
  467. {
  468. mexErrMsgTxt("Test: Unexpected arguments.");
  469. }
  470. //------------- read the data --------------
  471. NICE::SparseVector * newExample;
  472. double newLabel;
  473. if ( mxIsSparse( prhs[2] ) )
  474. {
  475. newExample = new NICE::SparseVector ( MatlabConversion::convertSparseVectorToNice( prhs[2] ) );
  476. }
  477. else
  478. {
  479. NICE::Vector * example;
  480. example = new NICE::Vector ( MatlabConversion::convertDoubleVectorToNice(prhs[2]) );
  481. newExample = new NICE::SparseVector ( *example );
  482. //----------------- clean up -------------
  483. delete example;
  484. }
  485. newLabel = MatlabConversion::convertMatlabToDouble( prhs[3] );
  486. // setting performOptimizationAfterIncrement is optional
  487. if ( nrhs > 4 )
  488. {
  489. bool performOptimizationAfterIncrement;
  490. performOptimizationAfterIncrement = MatlabConversion::convertMatlabToBool( prhs[4] );
  491. classifier->addExample ( newExample, newLabel, performOptimizationAfterIncrement );
  492. }
  493. else
  494. {
  495. classifier->addExample ( newExample, newLabel );
  496. }
  497. //----------------- clean up -------------
  498. delete newExample;
  499. return;
  500. }
  501. // addMultipleExamples
  502. if ( !strcmp("addMultipleExamples", cmd.c_str() ) )
  503. {
  504. // Check parameters
  505. if ( (nlhs < 0) || (nrhs < 4) )
  506. {
  507. mexErrMsgTxt("Test: Unexpected arguments.");
  508. }
  509. //------------- read the data --------------
  510. std::vector< const NICE::SparseVector *> newExamples;
  511. NICE::Vector newLabels;
  512. if ( mxIsSparse( prhs[2] ) )
  513. {
  514. newExamples = MatlabConversion::convertSparseMatrixToNice( prhs[2] );
  515. }
  516. else
  517. {
  518. NICE::Matrix newData;
  519. newData = MatlabConversion::convertDoubleMatrixToNice(prhs[2]);
  520. //----------------- convert data to sparse data structures ---------
  521. newExamples.resize( newData.rows() );
  522. std::vector< const NICE::SparseVector *>::iterator exTrainIt = newExamples.begin();
  523. for (int i = 0; i < (int)newData.rows(); i++, exTrainIt++)
  524. {
  525. *exTrainIt = new NICE::SparseVector( newData.getRow(i) );
  526. }
  527. }
  528. newLabels = MatlabConversion::convertDoubleVectorToNice(prhs[3]);
  529. // setting performOptimizationAfterIncrement is optional
  530. if ( nrhs > 4 )
  531. {
  532. bool performOptimizationAfterIncrement;
  533. performOptimizationAfterIncrement = MatlabConversion::convertMatlabToBool( prhs[4] );
  534. classifier->addMultipleExamples ( newExamples, newLabels, performOptimizationAfterIncrement );
  535. }
  536. else
  537. {
  538. classifier->addMultipleExamples ( newExamples, newLabels );
  539. }
  540. //----------------- clean up -------------
  541. for ( std::vector< const NICE::SparseVector *>::iterator exIt = newExamples.begin();
  542. exIt != newExamples.end(); exIt++
  543. )
  544. {
  545. delete *exIt;
  546. }
  547. return;
  548. }
  549. ///////////////////// INTERFACE PERSISTENT /////////////////////
  550. // interface specific methods for store and restore
  551. ///////////////////// INTERFACE PERSISTENT /////////////////////
  552. // store the classifier to an external file
  553. if ( !strcmp("store", cmd.c_str() ) || !strcmp("save", cmd.c_str() ) )
  554. {
  555. // Check parameters
  556. if ( nrhs < 3 )
  557. mexErrMsgTxt("store: no destination given.");
  558. std::string s_destination = MatlabConversion::convertMatlabToString( prhs[2] );
  559. std::filebuf fb;
  560. fb.open ( s_destination.c_str(), ios::out );
  561. std::ostream os(&fb);
  562. //
  563. classifier->store( os );
  564. //
  565. fb.close();
  566. return;
  567. }
  568. // load classifier from external file
  569. if ( !strcmp("restore", cmd.c_str() ) || !strcmp("load", cmd.c_str() ) )
  570. {
  571. // Check parameters
  572. if ( nrhs < 3 )
  573. mexErrMsgTxt("restore: no destination given.");
  574. std::string s_destination = MatlabConversion::convertMatlabToString( prhs[2] );
  575. std::cerr << " aim at restoring the classifier from " << s_destination << std::endl;
  576. std::filebuf fbIn;
  577. fbIn.open ( s_destination.c_str(), ios::in );
  578. std::istream is (&fbIn);
  579. //
  580. classifier->restore( is );
  581. //
  582. fbIn.close();
  583. return;
  584. }
  585. // Got here, so command not recognized
  586. std::string errorMsg (cmd.c_str() );
  587. errorMsg += " -- command not recognized.";
  588. mexErrMsgTxt( errorMsg.c_str() );
  589. }
  590. #endif