GPHIK.cpp 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997
  1. /**
  2. * @file GPHIK.cpp
  3. * @author Alexander Freytag
  4. * @date 07-01-2014 (dd-m-yyyy)
  5. * @brief Matlab-Interface of our GPHIKClassifier, allowing for training, classification, optimization, variance prediction, incremental learning, and storing/re-storing.
  6. */
  7. // STL includes
  8. #include <math.h>
  9. #include <matrix.h>
  10. #include <mex.h>
  11. // NICE-core includes
  12. #include <core/basics/Config.h>
  13. #include <core/basics/Timer.h>
  14. #include <core/vector/MatrixT.h>
  15. #include <core/vector/VectorT.h>
  16. // gp-hik-core includes
  17. #include "gp-hik-core/GPHIKClassifier.h"
  18. // Interface for conversion between Matlab and C objects
  19. #include "classHandleMtoC.h"
  20. using namespace std; //C basics
  21. using namespace NICE; // nice-core
  22. /* A sparse mxArray only stores its nonzero elements.
  23. * The values of the nonzero elements are stored in
  24. * the pr and pi arrays. The tricky part of analyzing
  25. * sparse mxArray's is figuring out the indices where
  26. * the nonzero elements are stored.
  27. * (See the mxSetIr and mxSetJc reference pages for details. */
  28. std::vector< const NICE::SparseVector * > convertSparseMatrixToNice(const mxArray *array_ptr)
  29. {
  30. double *pr;//, *pi;
  31. mwIndex *ir, *jc;
  32. mwSize col, total=0;
  33. mwIndex starting_row_index, stopping_row_index, current_row_index;
  34. mwSize i_numExamples, i_numDim;
  35. /* Get the starting positions of all four data arrays. */
  36. pr = mxGetPr( array_ptr );
  37. // no complex data supported here
  38. // pi = mxGetPi(array_ptr);
  39. ir = mxGetIr( array_ptr );
  40. jc = mxGetJc( array_ptr );
  41. // dimenions of the matrix -> feature dimension and number of examples
  42. i_numExamples = mxGetM( array_ptr );
  43. i_numDim = mxGetN( array_ptr );
  44. // initialize output variable -- don't use const pointers here since the content of the vectors will change
  45. // in the following loop. We reinterprete the vector lateron into a const version
  46. std::vector< NICE::SparseVector * > sparseMatrix;
  47. sparseMatrix.resize ( i_numExamples );
  48. for ( std::vector< NICE::SparseVector * >::iterator matIt = sparseMatrix.begin();
  49. matIt != sparseMatrix.end(); matIt++)
  50. {
  51. *matIt = new NICE::SparseVector( i_numDim );
  52. }
  53. // now copy the data
  54. for ( col = 0; col < i_numDim; col++ )
  55. {
  56. starting_row_index = jc[col];
  57. stopping_row_index = jc[col+1];
  58. // empty column?
  59. if (starting_row_index == stopping_row_index)
  60. continue;
  61. else
  62. {
  63. for ( current_row_index = starting_row_index;
  64. current_row_index < stopping_row_index;
  65. current_row_index++
  66. )
  67. {
  68. //note: no complex data supported her
  69. sparseMatrix[ ir[current_row_index] ]->insert( std::pair<int, double>( col, pr[total++] ) );
  70. } // for-loop
  71. }
  72. } // for-loop over columns
  73. //NOTE
  74. // Compiler doesn't know how to automatically convert
  75. // std::vector<T*> to std::vector<T const*> because the way
  76. // the template system works means that in theory the two may
  77. // be specialised differently. This is an explicit conversion.
  78. return reinterpret_cast< std::vector< const NICE::SparseVector *> &>( sparseMatrix );
  79. }
  80. // b_adaptIndexMtoC: if true, dim k will be inserted as k, not as k-1 (which would be the default for M->C)
  81. NICE::SparseVector convertSparseVectorToNice(const mxArray* array_ptr, const bool & b_adaptIndexMtoC = false )
  82. {
  83. double *pr, *pi;
  84. mwIndex *ir, *jc;
  85. mwSize col, total=0;
  86. mwIndex starting_row_index, stopping_row_index, current_row_index;
  87. mwSize dimy, dimx;
  88. /* Get the starting positions of all four data arrays. */
  89. pr = mxGetPr( array_ptr );
  90. pi = mxGetPi( array_ptr );
  91. ir = mxGetIr( array_ptr );
  92. jc = mxGetJc( array_ptr );
  93. // dimenions of the matrix -> feature dimension and number of examples
  94. dimy = mxGetM( array_ptr );
  95. dimx = mxGetN( array_ptr );
  96. double* ptr = mxGetPr( array_ptr );
  97. if( (dimx != 1) && (dimy != 1) )
  98. mexErrMsgIdAndTxt("mexnice:error","Vector expected");
  99. NICE::SparseVector svec( std::max(dimx, dimy) );
  100. if ( dimx > 1)
  101. {
  102. for ( mwSize row=0; row < dimx; row++)
  103. {
  104. // empty column?
  105. if (jc[row] == jc[row+1])
  106. {
  107. continue;
  108. }
  109. else
  110. {
  111. //note: no complex data supported her
  112. double value ( pr[total++] );
  113. if ( b_adaptIndexMtoC )
  114. svec.insert( std::pair<int, double>( row+1, value ) );
  115. else
  116. svec.insert( std::pair<int, double>( row, value ) );
  117. }
  118. } // for loop over cols
  119. }
  120. else
  121. {
  122. mwSize numNonZero = jc[1]-jc[0];
  123. for ( mwSize colNonZero=0; colNonZero < numNonZero; colNonZero++)
  124. {
  125. //note: no complex data supported her
  126. double value ( pr[total++] );
  127. if ( b_adaptIndexMtoC )
  128. svec.insert( std::pair<int, double>( ir[colNonZero]+1, value ) );
  129. else
  130. svec.insert( std::pair<int, double>( ir[colNonZero], value ) );
  131. }
  132. }
  133. return svec;
  134. }
  135. // b_adaptIndexCtoM: if true, dim k will be inserted as k, not as k+1 (which would be the default for C->M)
  136. mxArray* convertSparseVectorFromNice( const NICE::SparseVector & scores, const bool & b_adaptIndexCtoM = false)
  137. {
  138. mxArray * matlabSparseVec = mxCreateSparse( scores.getDim() /*m*/, 1/*n*/, scores.size()/*nzmax*/, mxREAL);
  139. // To make the returned sparse mxArray useful, you must initialize the pr, ir, jc, and (if it exists) pi arrays.
  140. // mxCreateSparse allocates space for:
  141. //
  142. // A pr array of length nzmax.
  143. // A pi array of length nzmax, but only if ComplexFlag is mxCOMPLEX in C (1 in Fortran).
  144. // An ir array of length nzmax.
  145. // A jc array of length n+1.
  146. double* prPtr = mxGetPr(matlabSparseVec);
  147. mwIndex * ir = mxGetIr( matlabSparseVec );
  148. mwIndex * jc = mxGetJc( matlabSparseVec );
  149. jc[1] = scores.size(); jc[0] = 0;
  150. mwSize cnt = 0;
  151. for ( NICE::SparseVector::const_iterator myIt = scores.begin(); myIt != scores.end(); myIt++, cnt++ )
  152. {
  153. // set index
  154. if ( b_adaptIndexCtoM )
  155. ir[cnt] = myIt->first-1;
  156. else
  157. ir[cnt] = myIt->first;
  158. // set value
  159. prPtr[cnt] = myIt->second;
  160. }
  161. return matlabSparseVec;
  162. }
  163. mxArray* convertMatrixFromNice(NICE::Matrix & niceMatrix)
  164. {
  165. mxArray *matlabMatrix = mxCreateDoubleMatrix( niceMatrix.rows(), niceMatrix.cols(), mxREAL );
  166. double* matlabMatrixPtr = mxGetPr( matlabMatrix );
  167. for( int i = 0; i < niceMatrix.rows(); i++ )
  168. {
  169. for( int j = 0; j < niceMatrix.cols(); j++ )
  170. {
  171. matlabMatrixPtr[i + j*niceMatrix.rows()] = niceMatrix(i,j);
  172. }
  173. }
  174. return matlabMatrix;
  175. }
  176. NICE::Matrix convertDoubleMatrixToNice(const mxArray* matlabMatrix)
  177. {
  178. if( !mxIsDouble( matlabMatrix ) )
  179. mexErrMsgIdAndTxt( "mexnice:error","Expected double in convertDoubleMatrixToNice" );
  180. const mwSize *dims;
  181. int dimx, dimy, numdims;
  182. //figure out dimensions
  183. dims = mxGetDimensions( matlabMatrix );
  184. numdims = mxGetNumberOfDimensions( matlabMatrix );
  185. dimy = (int)dims[0];
  186. dimx = (int)dims[1];
  187. double* ptr = mxGetPr( matlabMatrix );
  188. NICE::Matrix niceMatrix(ptr, dimy, dimx, NICE::Matrix::external);
  189. return niceMatrix;
  190. }
  191. mxArray* convertVectorFromNice(NICE::Vector & niceVector)
  192. {
  193. mxArray *matlabVector = mxCreateDoubleMatrix( niceVector.size(), 1, mxREAL );
  194. double* matlabVectorPtr = mxGetPr( matlabVector );
  195. for( int i = 0; i < niceVector.size(); i++ )
  196. {
  197. matlabVectorPtr[i] = niceVector[i];
  198. }
  199. return matlabVector;
  200. }
  201. NICE::Vector convertDoubleVectorToNice( const mxArray* matlabMatrix )
  202. {
  203. if( !mxIsDouble( matlabMatrix ) )
  204. mexErrMsgIdAndTxt( "mexnice:error","Expected double in convertDoubleVectorToNice" );
  205. const mwSize *dims;
  206. int dimx, dimy, numdims;
  207. //figure out dimensions
  208. dims = mxGetDimensions( matlabMatrix );
  209. numdims = mxGetNumberOfDimensions( matlabMatrix );
  210. dimy = (int)dims[0];
  211. dimx = (int)dims[1];
  212. double* ptr = mxGetPr( matlabMatrix );
  213. if( (dimx != 1) && (dimy != 1) )
  214. mexErrMsgIdAndTxt("mexnice:error","Vector expected");
  215. int dim = std::max(dimx, dimy);
  216. NICE::Vector niceVector( dim, 0.0 );
  217. for( int i = 0; i < dim; i++ )
  218. {
  219. niceVector(i) = ptr[i];
  220. }
  221. return niceVector;
  222. }
  223. std::string convertMatlabToString( const mxArray *matlabString )
  224. {
  225. if( !mxIsChar( matlabString ) )
  226. mexErrMsgIdAndTxt("mexnice:error","Expected string");
  227. char *cstring = mxArrayToString( matlabString );
  228. std::string s( cstring );
  229. mxFree(cstring);
  230. return s;
  231. }
  232. int convertMatlabToInt32( const mxArray *matlabInt32 )
  233. {
  234. if( !mxIsInt32( matlabInt32 ) )
  235. mexErrMsgIdAndTxt("mexnice:error","Expected int32");
  236. int* ptr = (int*) mxGetData( matlabInt32 );
  237. return ptr[0];
  238. }
  239. double convertMatlabToDouble( const mxArray *matlabDouble )
  240. {
  241. if( !mxIsDouble(matlabDouble) )
  242. mexErrMsgIdAndTxt("mexnice:error","Expected double");
  243. double* ptr = (double*) mxGetData( matlabDouble );
  244. return ptr[0];
  245. }
  246. bool convertMatlabToBool(const mxArray *matlabBool)
  247. {
  248. if( !mxIsLogical( matlabBool ) )
  249. mexErrMsgIdAndTxt("mexnice:error","Expected bool");
  250. bool* ptr = (bool*) mxGetData( matlabBool );
  251. return ptr[0];
  252. }
  253. NICE::Config parseParameters(const mxArray *prhs[], int nrhs)
  254. {
  255. NICE::Config conf;
  256. // if first argument is the filename of an existing config file,
  257. // read the config accordingly
  258. int i_start ( 0 );
  259. std::string variable = convertMatlabToString(prhs[i_start]);
  260. if(variable == "conf")
  261. {
  262. conf = NICE::Config ( convertMatlabToString( prhs[i_start+1] ) );
  263. i_start = i_start+2;
  264. }
  265. // now run over all given parameter specifications
  266. // and add them to the config
  267. for( int i=i_start; i < nrhs; i+=2 )
  268. {
  269. std::string variable = convertMatlabToString(prhs[i]);
  270. if(variable == "ils_verbose")
  271. {
  272. string value = convertMatlabToString(prhs[i+1]);
  273. if(value != "true" && value != "false")
  274. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_verbose\'. \'true\' or \'false\' expected.");
  275. if(value == "true")
  276. conf.sB("GPHIKClassifier", variable, true);
  277. else
  278. conf.sB("GPHIKClassifier", variable, false);
  279. }
  280. if(variable == "ils_max_iterations")
  281. {
  282. int value = convertMatlabToInt32(prhs[i+1]);
  283. if(value < 1)
  284. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'ils_max_iterations\'.");
  285. conf.sI("GPHIKClassifier", variable, value);
  286. }
  287. if(variable == "ils_method")
  288. {
  289. string value = convertMatlabToString(prhs[i+1]);
  290. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  291. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  292. conf.sS("GPHIKClassifier", variable, value);
  293. }
  294. if(variable == "ils_min_delta")
  295. {
  296. double value = convertMatlabToDouble(prhs[i+1]);
  297. if(value < 0.0)
  298. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'ils_min_delta\'.");
  299. conf.sD("GPHIKClassifier", variable, value);
  300. }
  301. if(variable == "ils_min_residual")
  302. {
  303. double value = convertMatlabToDouble(prhs[i+1]);
  304. if(value < 0.0)
  305. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'ils_min_residual\'.");
  306. conf.sD("GPHIKClassifier", variable, value);
  307. }
  308. if(variable == "optimization_method")
  309. {
  310. string value = convertMatlabToString(prhs[i+1]);
  311. if(value != "greedy" && value != "downhillsimplex" && value != "none")
  312. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimization_method\'. \'greedy\', \'downhillsimplex\' or \'none\' expected.");
  313. conf.sS("GPHIKClassifier", variable, value);
  314. }
  315. if(variable == "use_quantization")
  316. {
  317. string value = convertMatlabToString(prhs[i+1]);
  318. if(value != "true" && value != "false")
  319. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'use_quantization\'. \'true\' or \'false\' expected.");
  320. if(value == "true")
  321. conf.sB("GPHIKClassifier", variable, true);
  322. else
  323. conf.sB("GPHIKClassifier", variable, false);
  324. }
  325. if(variable == "num_bins")
  326. {
  327. int value = convertMatlabToInt32(prhs[i+1]);
  328. if(value < 1)
  329. mexErrMsgIdAndTxt("mexnice:error","Expected parameter value larger than 0 for \'num_bins\'.");
  330. conf.sI("GPHIKClassifier", variable, value);
  331. }
  332. if(variable == "transform")
  333. {
  334. string value = convertMatlabToString(prhs[i+1]);
  335. if(value != "absexp" && value != "exp" && value != "MKL" && value != "WeightedDim")
  336. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'transform\'. \'absexp\', \'exp\' , \'MKL\' or \'WeightedDim\' expected.");
  337. conf.sS("GPHIKClassifier", variable, value);
  338. }
  339. if(variable == "verboseTime")
  340. {
  341. string value = convertMatlabToString(prhs[i+1]);
  342. if(value != "true" && value != "false")
  343. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'verboseTime\'. \'true\' or \'false\' expected.");
  344. if(value == "true")
  345. conf.sB("GPHIKClassifier", variable, true);
  346. else
  347. conf.sB("GPHIKClassifier", variable, false);
  348. }
  349. if(variable == "verbose")
  350. {
  351. string value = convertMatlabToString(prhs[i+1]);
  352. if(value != "true" && value != "false")
  353. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'verbose\'. \'true\' or \'false\' expected.");
  354. if(value == "true")
  355. conf.sB("GPHIKClassifier", variable, true);
  356. else
  357. conf.sB("GPHIKClassifier", variable, false);
  358. }
  359. if(variable == "noise")
  360. {
  361. double value = convertMatlabToDouble(prhs[i+1]);
  362. if(value < 0.0)
  363. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value larger than 0 for \'noise\'.");
  364. conf.sD("GPHIKClassifier", variable, value);
  365. }
  366. if(variable == "optimize_noise")
  367. {
  368. string value = convertMatlabToString(prhs[i+1]);
  369. if(value != "true" && value != "false")
  370. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimize_noise\'. \'true\' or \'false\' expected.");
  371. if(value == "true")
  372. conf.sB("GPHIKClassifier", variable, true);
  373. else
  374. conf.sB("GPHIKClassifier", variable, false);
  375. }
  376. if(variable == "varianceApproximation")
  377. {
  378. string value = convertMatlabToString(prhs[i+1]);
  379. if(value != "approximate_fine" && value != "approximate_rough" && value != "exact" && value != "none")
  380. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'varianceApproximation\'. \'approximate_fine\', \'approximate_rough\', \'none\' or \'exact\' expected.");
  381. conf.sS("GPHIKClassifier", variable, value);
  382. }
  383. if(variable == "nrOfEigenvaluesToConsiderForVarApprox")
  384. {
  385. double value = convertMatlabToDouble(prhs[i+1]);
  386. conf.sI("GPHIKClassifier", variable, (int) value);
  387. }
  388. }
  389. return conf;
  390. }
  391. // MAIN MATLAB FUNCTION
  392. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  393. {
  394. // get the command string specifying what to do
  395. if (nrhs < 1)
  396. mexErrMsgTxt("No commands and options passed... Aborting!");
  397. if( !mxIsChar( prhs[0] ) )
  398. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  399. std::string cmd = convertMatlabToString( prhs[0] );
  400. // create object
  401. if ( !strcmp("new", cmd.c_str() ) )
  402. {
  403. // check output variable
  404. if (nlhs != 1)
  405. mexErrMsgTxt("New: One output expected.");
  406. // read config settings
  407. NICE::Config conf = parseParameters(prhs+1,nrhs-1);
  408. // create class instance
  409. NICE::GPHIKClassifier * classifier = new NICE::GPHIKClassifier ( &conf );
  410. // handle to the C++ instance
  411. plhs[0] = convertPtr2Mat<NICE::GPHIKClassifier>( classifier );
  412. return;
  413. }
  414. // in all other cases, there should be a second input,
  415. // which the be the class instance handle
  416. if (nrhs < 2)
  417. mexErrMsgTxt("Second input should be a class instance handle.");
  418. // delete object
  419. if ( !strcmp("delete", cmd.c_str() ) )
  420. {
  421. // Destroy the C++ object
  422. destroyObject<NICE::GPHIKClassifier>(prhs[1]);
  423. return;
  424. }
  425. // get the class instance pointer from the second input
  426. // every following function needs the classifier object
  427. NICE::GPHIKClassifier * classifier = convertMat2Ptr<NICE::GPHIKClassifier>(prhs[1]);
  428. ////////////////////////////////////////
  429. // Check which class method to call //
  430. ////////////////////////////////////////
  431. // standard train - assumes initialized object
  432. if (!strcmp("train", cmd.c_str() ))
  433. {
  434. // Check parameters
  435. if (nlhs < 0 || nrhs < 4)
  436. {
  437. mexErrMsgTxt("Train: Unexpected arguments.");
  438. }
  439. //------------- read the data --------------
  440. std::vector< const NICE::SparseVector *> examplesTrain;
  441. NICE::Vector yMultiTrain;
  442. if ( mxIsSparse( prhs[2] ) )
  443. {
  444. examplesTrain = convertSparseMatrixToNice( prhs[2] );
  445. }
  446. else
  447. {
  448. NICE::Matrix dataTrain;
  449. dataTrain = convertDoubleMatrixToNice(prhs[2]);
  450. //----------------- convert data to sparse data structures ---------
  451. examplesTrain.resize( dataTrain.rows() );
  452. std::vector< const NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  453. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  454. {
  455. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  456. }
  457. }
  458. yMultiTrain = convertDoubleVectorToNice(prhs[3]);
  459. //----------------- train our classifier -------------
  460. classifier->train ( examplesTrain , yMultiTrain );
  461. //----------------- clean up -------------
  462. for(int i=0;i<examplesTrain.size();i++)
  463. delete examplesTrain[i];
  464. return;
  465. }
  466. // Classify
  467. if ( !strcmp("classify", cmd.c_str() ) )
  468. {
  469. // Check parameters
  470. if ( (nlhs < 0) || (nrhs < 2) )
  471. {
  472. mexErrMsgTxt("Test: Unexpected arguments.");
  473. }
  474. //------------- read the data --------------
  475. int result;
  476. NICE::SparseVector scores;
  477. double uncertainty;
  478. if ( mxIsSparse( prhs[2] ) )
  479. {
  480. NICE::SparseVector * example;
  481. example = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  482. classifier->classify ( example, result, scores, uncertainty );
  483. //----------------- clean up -------------
  484. delete example;
  485. }
  486. else
  487. {
  488. NICE::Vector * example;
  489. example = new NICE::Vector ( convertDoubleVectorToNice(prhs[2]) );
  490. classifier->classify ( example, result, scores, uncertainty );
  491. //----------------- clean up -------------
  492. delete example;
  493. }
  494. // output
  495. plhs[0] = mxCreateDoubleScalar( result );
  496. if(nlhs >= 2)
  497. {
  498. plhs[1] = convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  499. }
  500. if(nlhs >= 3)
  501. {
  502. plhs[2] = mxCreateDoubleScalar( uncertainty );
  503. }
  504. return;
  505. }
  506. // Uncertainty prediction
  507. if ( !strcmp("uncertainty", cmd.c_str() ) )
  508. {
  509. // Check parameters
  510. if ( (nlhs < 0) || (nrhs < 2) )
  511. {
  512. mexErrMsgTxt("Test: Unexpected arguments.");
  513. }
  514. double uncertainty;
  515. //------------- read the data --------------
  516. if ( mxIsSparse( prhs[2] ) )
  517. {
  518. NICE::SparseVector * example;
  519. example = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  520. classifier->predictUncertainty( example, uncertainty );
  521. //----------------- clean up -------------
  522. delete example;
  523. }
  524. else
  525. {
  526. NICE::Vector * example;
  527. example = new NICE::Vector ( convertDoubleVectorToNice(prhs[2]) );
  528. classifier->predictUncertainty( example, uncertainty );
  529. //----------------- clean up -------------
  530. delete example;
  531. }
  532. // output
  533. plhs[0] = mxCreateDoubleScalar( uncertainty );
  534. return;
  535. }
  536. // Test
  537. if ( !strcmp("test", cmd.c_str() ) )
  538. {
  539. // Check parameters
  540. if (nlhs < 0 || nrhs < 4)
  541. mexErrMsgTxt("Test: Unexpected arguments.");
  542. //------------- read the data --------------
  543. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  544. std::vector< const NICE::SparseVector *> dataTest_sparse;
  545. NICE::Matrix dataTest_dense;
  546. if ( dataIsSparse )
  547. {
  548. dataTest_sparse = convertSparseMatrixToNice( prhs[2] );
  549. }
  550. else
  551. {
  552. dataTest_dense = convertDoubleMatrixToNice(prhs[2]);
  553. }
  554. NICE::Vector yMultiTest;
  555. yMultiTest = convertDoubleVectorToNice(prhs[3]);
  556. // ------------------------------------------
  557. // ------------- PREPARATION --------------
  558. // ------------------------------------------
  559. // determine classes known during training and corresponding mapping
  560. // thereby allow for non-continous class labels
  561. std::set<int> classesKnownTraining = classifier->getKnownClassNumbers();
  562. int noClassesKnownTraining ( classesKnownTraining.size() );
  563. std::map<int,int> mapClNoToIdxTrain;
  564. std::set<int>::const_iterator clTrIt = classesKnownTraining.begin();
  565. for ( int i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  566. mapClNoToIdxTrain.insert ( std::pair<int,int> ( *clTrIt, i ) );
  567. // determine classes known during testing and corresponding mapping
  568. // thereby allow for non-continous class labels
  569. std::set<int> classesKnownTest;
  570. classesKnownTest.clear();
  571. // determine which classes we have in our label vector
  572. // -> MATLAB: myClasses = unique(y);
  573. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  574. {
  575. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  576. {
  577. classesKnownTest.insert ( *it );
  578. }
  579. }
  580. int noClassesKnownTest ( classesKnownTest.size() );
  581. std::map<int,int> mapClNoToIdxTest;
  582. std::set<int>::const_iterator clTestIt = classesKnownTest.begin();
  583. for ( int i=0; i < noClassesKnownTest; i++, clTestIt++ )
  584. mapClNoToIdxTest.insert ( std::pair<int,int> ( *clTestIt, i ) );
  585. int i_numTestSamples;
  586. if ( dataIsSparse )
  587. i_numTestSamples = dataTest_sparse.size();
  588. else
  589. i_numTestSamples = (int) dataTest_dense.rows();
  590. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  591. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  592. // ------------------------------------------
  593. // ------------- CLASSIFICATION --------------
  594. // ------------------------------------------
  595. NICE::Timer t;
  596. double testTime (0.0);
  597. for (int i = 0; i < i_numTestSamples; i++)
  598. {
  599. //----------------- convert data to sparse data structures ---------
  600. int result;
  601. NICE::SparseVector exampleScoresSparse;
  602. if ( dataIsSparse )
  603. {
  604. // and classify
  605. t.start();
  606. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  607. t.stop();
  608. testTime += t.getLast();
  609. }
  610. else
  611. {
  612. NICE::Vector example ( dataTest_dense.getRow(i) );
  613. // and classify
  614. t.start();
  615. classifier->classify( &example, result, exampleScoresSparse );
  616. t.stop();
  617. testTime += t.getLast();
  618. }
  619. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  620. int scoreCnt ( 0 );
  621. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  622. {
  623. scores(i,scoreCnt) = scoreIt->second;
  624. }
  625. }
  626. std::cerr << "Time for testing: " << testTime << std::endl;
  627. // clean up
  628. if ( dataIsSparse )
  629. {
  630. for ( std::vector<const NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  631. delete *it;
  632. }
  633. confusionMatrix.normalizeColumnsL1();
  634. //std::cerr << confusionMatrix << std::endl;
  635. double recRate = confusionMatrix.trace()/confusionMatrix.rows();
  636. //std::cerr << "average recognition rate: " << recRate << std::endl;
  637. plhs[0] = mxCreateDoubleScalar( recRate );
  638. if(nlhs >= 2)
  639. plhs[1] = convertMatrixFromNice(confusionMatrix);
  640. if(nlhs >= 3)
  641. plhs[2] = convertMatrixFromNice(scores);
  642. return;
  643. }
  644. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  645. // interface specific methods for incremental extensions
  646. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  647. // addExample
  648. if ( !strcmp("addExample", cmd.c_str() ) )
  649. {
  650. // Check parameters
  651. if ( (nlhs < 0) || (nrhs < 4) )
  652. {
  653. mexErrMsgTxt("Test: Unexpected arguments.");
  654. }
  655. //------------- read the data --------------
  656. NICE::SparseVector * newExample;
  657. double newLabel;
  658. if ( mxIsSparse( prhs[2] ) )
  659. {
  660. newExample = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  661. }
  662. else
  663. {
  664. NICE::Vector * example;
  665. example = new NICE::Vector ( convertDoubleVectorToNice(prhs[2]) );
  666. newExample = new NICE::SparseVector ( *example );
  667. //----------------- clean up -------------
  668. delete example;
  669. }
  670. newLabel = convertMatlabToDouble( prhs[3] );
  671. // setting performOptimizationAfterIncrement is optional
  672. if ( nrhs > 4 )
  673. {
  674. bool performOptimizationAfterIncrement;
  675. performOptimizationAfterIncrement = convertMatlabToBool( prhs[4] );
  676. classifier->addExample ( newExample, newLabel, performOptimizationAfterIncrement );
  677. }
  678. else
  679. {
  680. classifier->addExample ( newExample, newLabel );
  681. }
  682. //----------------- clean up -------------
  683. delete newExample;
  684. return;
  685. }
  686. // addExample
  687. if ( !strcmp("addMultipleExamples", cmd.c_str() ) )
  688. {
  689. // Check parameters
  690. if ( (nlhs < 0) || (nrhs < 4) )
  691. {
  692. mexErrMsgTxt("Test: Unexpected arguments.");
  693. }
  694. //------------- read the data --------------
  695. std::vector< const NICE::SparseVector *> newExamples;
  696. NICE::Vector newLabels;
  697. if ( mxIsSparse( prhs[2] ) )
  698. {
  699. newExamples = convertSparseMatrixToNice( prhs[2] );
  700. }
  701. else
  702. {
  703. NICE::Matrix newData;
  704. newData = convertDoubleMatrixToNice(prhs[2]);
  705. //----------------- convert data to sparse data structures ---------
  706. newExamples.resize( newData.rows() );
  707. std::vector< const NICE::SparseVector *>::iterator exTrainIt = newExamples.begin();
  708. for (int i = 0; i < (int)newData.rows(); i++, exTrainIt++)
  709. {
  710. *exTrainIt = new NICE::SparseVector( newData.getRow(i) );
  711. }
  712. }
  713. newLabels = convertDoubleVectorToNice(prhs[3]);
  714. // setting performOptimizationAfterIncrement is optional
  715. if ( nrhs > 4 )
  716. {
  717. bool performOptimizationAfterIncrement;
  718. performOptimizationAfterIncrement = convertMatlabToBool( prhs[4] );
  719. classifier->addMultipleExamples ( newExamples, newLabels, performOptimizationAfterIncrement );
  720. }
  721. else
  722. {
  723. classifier->addMultipleExamples ( newExamples, newLabels );
  724. }
  725. //----------------- clean up -------------
  726. for ( std::vector< const NICE::SparseVector *>::iterator exIt = newExamples.begin();
  727. exIt != newExamples.end(); exIt++
  728. )
  729. {
  730. delete *exIt;
  731. }
  732. return;
  733. }
  734. ///////////////////// INTERFACE PERSISTENT /////////////////////
  735. // interface specific methods for store and restore
  736. ///////////////////// INTERFACE PERSISTENT /////////////////////
  737. // store the classifier
  738. if ( !strcmp("store", cmd.c_str() ) || !strcmp("save", cmd.c_str() ) )
  739. {
  740. // Check parameters
  741. if ( nrhs < 3 )
  742. mexErrMsgTxt("store: no destination given.");
  743. std::string s_destination = convertMatlabToString( prhs[2] );
  744. std::filebuf fb;
  745. fb.open ( s_destination.c_str(), ios::out );
  746. std::ostream os(&fb);
  747. //
  748. classifier->store( os );
  749. //
  750. fb.close();
  751. return;
  752. }
  753. // load classifier from external file
  754. if ( !strcmp("restore", cmd.c_str() ) || !strcmp("load", cmd.c_str() ) )
  755. {
  756. // Check parameters
  757. if ( nrhs < 3 )
  758. mexErrMsgTxt("restore: no destination given.");
  759. std::string s_destination = convertMatlabToString( prhs[2] );
  760. std::cerr << " aim at restoring the classifier from " << s_destination << std::endl;
  761. std::filebuf fbIn;
  762. fbIn.open ( s_destination.c_str(), ios::in );
  763. std::istream is (&fbIn);
  764. //
  765. classifier->restore( is );
  766. //
  767. fbIn.close();
  768. return;
  769. }
  770. // Got here, so command not recognized
  771. std::string errorMsg (cmd.c_str() );
  772. errorMsg += " -- command not recognized.";
  773. mexErrMsgTxt( errorMsg.c_str() );
  774. }