GPHIK.cpp 32 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015
  1. /**
  2. * @file GPHIK.cpp
  3. * @author Alexander Freytag
  4. * @date 07-01-2014 (dd-mm-yyyy)
  5. * @brief Matlab-Interface of our GPHIKClassifier, allowing for training, classification, optimization, variance prediction, incremental learning, and storing/re-storing.
  6. */
  7. // STL includes
  8. #include <math.h>
  9. #include <matrix.h>
  10. #include <mex.h>
  11. // NICE-core includes
  12. #include <core/basics/Config.h>
  13. #include <core/basics/Timer.h>
  14. #include <core/vector/MatrixT.h>
  15. #include <core/vector/VectorT.h>
  16. // gp-hik-core includes
  17. #include "gp-hik-core/GPHIKClassifier.h"
  18. // Interface for conversion between Matlab and C objects
  19. #include "classHandleMtoC.h"
  20. using namespace std; //C basics
  21. using namespace NICE; // nice-core
  22. /* A sparse mxArray only stores its nonzero elements.
  23. * The values of the nonzero elements are stored in
  24. * the pr and pi arrays. The tricky part of analyzing
  25. * sparse mxArray's is figuring out the indices where
  26. * the nonzero elements are stored.
  27. * (See the mxSetIr and mxSetJc reference pages for details. */
  28. std::vector< const NICE::SparseVector * > convertSparseMatrixToNice(const mxArray *array_ptr)
  29. {
  30. double *pr;//, *pi;
  31. mwIndex *ir, *jc;
  32. mwSize col, total=0;
  33. mwIndex starting_row_index, stopping_row_index, current_row_index;
  34. mwSize i_numExamples, i_numDim;
  35. /* Get the starting positions of all four data arrays. */
  36. pr = mxGetPr( array_ptr );
  37. // no complex data supported here
  38. // pi = mxGetPi(array_ptr);
  39. ir = mxGetIr( array_ptr );
  40. jc = mxGetJc( array_ptr );
  41. // dimenions of the matrix -> feature dimension and number of examples
  42. i_numExamples = mxGetM( array_ptr );
  43. i_numDim = mxGetN( array_ptr );
  44. // initialize output variable -- don't use const pointers here since the content of the vectors will change
  45. // in the following loop. We reinterprete the vector lateron into a const version
  46. std::vector< NICE::SparseVector * > sparseMatrix;
  47. sparseMatrix.resize ( i_numExamples );
  48. for ( std::vector< NICE::SparseVector * >::iterator matIt = sparseMatrix.begin();
  49. matIt != sparseMatrix.end(); matIt++)
  50. {
  51. *matIt = new NICE::SparseVector( i_numDim );
  52. }
  53. // now copy the data
  54. for ( col = 0; col < i_numDim; col++ )
  55. {
  56. starting_row_index = jc[col];
  57. stopping_row_index = jc[col+1];
  58. // empty column?
  59. if (starting_row_index == stopping_row_index)
  60. continue;
  61. else
  62. {
  63. for ( current_row_index = starting_row_index;
  64. current_row_index < stopping_row_index;
  65. current_row_index++
  66. )
  67. {
  68. //note: no complex data supported her
  69. sparseMatrix[ ir[current_row_index] ]->insert( std::pair<int, double>( col, pr[total++] ) );
  70. } // for-loop
  71. }
  72. } // for-loop over columns
  73. //NOTE
  74. // Compiler doesn't know how to automatically convert
  75. // std::vector<T*> to std::vector<T const*> because the way
  76. // the template system works means that in theory the two may
  77. // be specialised differently. This is an explicit conversion.
  78. return reinterpret_cast< std::vector< const NICE::SparseVector *> &>( sparseMatrix );
  79. }
  80. // b_adaptIndexMtoC: if true, dim k will be inserted as k, not as k-1 (which would be the default for M->C)
  81. NICE::SparseVector convertSparseVectorToNice(const mxArray* array_ptr, const bool & b_adaptIndexMtoC = false )
  82. {
  83. double *pr, *pi;
  84. mwIndex *ir, *jc;
  85. mwSize col, total=0;
  86. mwIndex starting_row_index, stopping_row_index, current_row_index;
  87. mwSize dimy, dimx;
  88. /* Get the starting positions of all four data arrays. */
  89. pr = mxGetPr( array_ptr );
  90. pi = mxGetPi( array_ptr );
  91. ir = mxGetIr( array_ptr );
  92. jc = mxGetJc( array_ptr );
  93. // dimenions of the matrix -> feature dimension and number of examples
  94. dimy = mxGetM( array_ptr );
  95. dimx = mxGetN( array_ptr );
  96. double* ptr = mxGetPr( array_ptr );
  97. if( (dimx != 1) && (dimy != 1) )
  98. mexErrMsgIdAndTxt("mexnice:error","Vector expected");
  99. NICE::SparseVector svec( std::max(dimx, dimy) );
  100. if ( dimx > 1)
  101. {
  102. for ( mwSize row=0; row < dimx; row++)
  103. {
  104. // empty column?
  105. if (jc[row] == jc[row+1])
  106. {
  107. continue;
  108. }
  109. else
  110. {
  111. //note: no complex data supported her
  112. double value ( pr[total++] );
  113. if ( b_adaptIndexMtoC )
  114. svec.insert( std::pair<int, double>( row+1, value ) );
  115. else
  116. svec.insert( std::pair<int, double>( row, value ) );
  117. }
  118. } // for loop over cols
  119. }
  120. else
  121. {
  122. mwSize numNonZero = jc[1]-jc[0];
  123. for ( mwSize colNonZero=0; colNonZero < numNonZero; colNonZero++)
  124. {
  125. //note: no complex data supported her
  126. double value ( pr[total++] );
  127. if ( b_adaptIndexMtoC )
  128. svec.insert( std::pair<int, double>( ir[colNonZero]+1, value ) );
  129. else
  130. svec.insert( std::pair<int, double>( ir[colNonZero], value ) );
  131. }
  132. }
  133. return svec;
  134. }
  135. // b_adaptIndexCtoM: if true, dim k will be inserted as k, not as k+1 (which would be the default for C->M)
  136. mxArray* convertSparseVectorFromNice( const NICE::SparseVector & scores, const bool & b_adaptIndexCtoM = false)
  137. {
  138. mxArray * matlabSparseVec = mxCreateSparse( scores.getDim() /*m*/, 1/*n*/, scores.size()/*nzmax*/, mxREAL);
  139. // To make the returned sparse mxArray useful, you must initialize the pr, ir, jc, and (if it exists) pi arrays.
  140. // mxCreateSparse allocates space for:
  141. //
  142. // A pr array of length nzmax.
  143. // A pi array of length nzmax, but only if ComplexFlag is mxCOMPLEX in C (1 in Fortran).
  144. // An ir array of length nzmax.
  145. // A jc array of length n+1.
  146. double* prPtr = mxGetPr(matlabSparseVec);
  147. mwIndex * ir = mxGetIr( matlabSparseVec );
  148. mwIndex * jc = mxGetJc( matlabSparseVec );
  149. jc[1] = scores.size(); jc[0] = 0;
  150. mwSize cnt = 0;
  151. for ( NICE::SparseVector::const_iterator myIt = scores.begin(); myIt != scores.end(); myIt++, cnt++ )
  152. {
  153. // set index
  154. if ( b_adaptIndexCtoM )
  155. ir[cnt] = myIt->first-1;
  156. else
  157. ir[cnt] = myIt->first;
  158. // set value
  159. prPtr[cnt] = myIt->second;
  160. }
  161. return matlabSparseVec;
  162. }
  163. mxArray* convertMatrixFromNice(NICE::Matrix & niceMatrix)
  164. {
  165. mxArray *matlabMatrix = mxCreateDoubleMatrix( niceMatrix.rows(), niceMatrix.cols(), mxREAL );
  166. double* matlabMatrixPtr = mxGetPr( matlabMatrix );
  167. for( int i = 0; i < niceMatrix.rows(); i++ )
  168. {
  169. for( int j = 0; j < niceMatrix.cols(); j++ )
  170. {
  171. matlabMatrixPtr[i + j*niceMatrix.rows()] = niceMatrix(i,j);
  172. }
  173. }
  174. return matlabMatrix;
  175. }
  176. NICE::Matrix convertDoubleMatrixToNice(const mxArray* matlabMatrix)
  177. {
  178. if( !mxIsDouble( matlabMatrix ) )
  179. mexErrMsgIdAndTxt( "mexnice:error","Expected double in convertDoubleMatrixToNice" );
  180. const mwSize *dims;
  181. int dimx, dimy, numdims;
  182. //figure out dimensions
  183. dims = mxGetDimensions( matlabMatrix );
  184. numdims = mxGetNumberOfDimensions( matlabMatrix );
  185. dimy = (int)dims[0];
  186. dimx = (int)dims[1];
  187. double* ptr = mxGetPr( matlabMatrix );
  188. NICE::Matrix niceMatrix(ptr, dimy, dimx, NICE::Matrix::external);
  189. return niceMatrix;
  190. }
  191. mxArray* convertVectorFromNice(NICE::Vector & niceVector)
  192. {
  193. mxArray *matlabVector = mxCreateDoubleMatrix( niceVector.size(), 1, mxREAL );
  194. double* matlabVectorPtr = mxGetPr( matlabVector );
  195. for( int i = 0; i < niceVector.size(); i++ )
  196. {
  197. matlabVectorPtr[i] = niceVector[i];
  198. }
  199. return matlabVector;
  200. }
  201. NICE::Vector convertDoubleVectorToNice( const mxArray* matlabMatrix )
  202. {
  203. if( !mxIsDouble( matlabMatrix ) )
  204. mexErrMsgIdAndTxt( "mexnice:error","Expected double in convertDoubleVectorToNice" );
  205. const mwSize *dims;
  206. int dimx, dimy, numdims;
  207. //figure out dimensions
  208. dims = mxGetDimensions( matlabMatrix );
  209. numdims = mxGetNumberOfDimensions( matlabMatrix );
  210. dimy = (int)dims[0];
  211. dimx = (int)dims[1];
  212. double* ptr = mxGetPr( matlabMatrix );
  213. if( (dimx != 1) && (dimy != 1) )
  214. mexErrMsgIdAndTxt("mexnice:error","Vector expected");
  215. int dim = std::max(dimx, dimy);
  216. NICE::Vector niceVector( dim, 0.0 );
  217. for( int i = 0; i < dim; i++ )
  218. {
  219. niceVector(i) = ptr[i];
  220. }
  221. return niceVector;
  222. }
  223. std::string convertMatlabToString( const mxArray *matlabString )
  224. {
  225. if( !mxIsChar( matlabString ) )
  226. mexErrMsgIdAndTxt("mexnice:error","Expected string");
  227. char *cstring = mxArrayToString( matlabString );
  228. std::string s( cstring );
  229. mxFree(cstring);
  230. return s;
  231. }
  232. int convertMatlabToInt32( const mxArray *matlabInt32 )
  233. {
  234. if( !mxIsInt32( matlabInt32 ) )
  235. mexErrMsgIdAndTxt("mexnice:error","Expected int32");
  236. int* ptr = (int*) mxGetData( matlabInt32 );
  237. return ptr[0];
  238. }
  239. double convertMatlabToDouble( const mxArray *matlabDouble )
  240. {
  241. if( !mxIsDouble(matlabDouble) )
  242. mexErrMsgIdAndTxt("mexnice:error","Expected double");
  243. double* ptr = (double*) mxGetData( matlabDouble );
  244. return ptr[0];
  245. }
  246. bool convertMatlabToBool(const mxArray *matlabBool)
  247. {
  248. if( !mxIsLogical( matlabBool ) )
  249. mexErrMsgIdAndTxt("mexnice:error","Expected bool");
  250. bool* ptr = (bool*) mxGetData( matlabBool );
  251. return ptr[0];
  252. }
  253. NICE::Config parseParameters(const mxArray *prhs[], int nrhs)
  254. {
  255. NICE::Config conf;
  256. // if first argument is the filename of an existing config file,
  257. // read the config accordingly
  258. int i_start ( 0 );
  259. std::string variable = convertMatlabToString(prhs[i_start]);
  260. if(variable == "conf")
  261. {
  262. conf = NICE::Config ( convertMatlabToString( prhs[i_start+1] ) );
  263. i_start = i_start+2;
  264. }
  265. // now run over all given parameter specifications
  266. // and add them to the config
  267. for( int i=i_start; i < nrhs; i+=2 )
  268. {
  269. std::string variable = convertMatlabToString(prhs[i]);
  270. /////////////////////////////////////////
  271. // READ STANDARD BOOLEAN VARIABLES
  272. /////////////////////////////////////////
  273. if( (variable == "verboseTime") || (variable == "verbose") ||
  274. (variable == "optimize_noise") || (variable == "uncertaintyPredictionForClassification") ||
  275. (variable == "use_quantization") || (variable == "ils_verbose")
  276. )
  277. {
  278. if ( mxIsChar( prhs[i+1] ) )
  279. {
  280. string value = convertMatlabToString( prhs[i+1] );
  281. if ( (value != "true") && (value != "false") )
  282. {
  283. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. In string modus, \'true\' or \'false\' expected.";
  284. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  285. }
  286. if( value == "true" )
  287. conf.sB("GPHIKClassifier", variable, true);
  288. else
  289. conf.sB("GPHIKClassifier", variable, false);
  290. }
  291. else if ( mxIsLogical( prhs[i+1] ) )
  292. {
  293. bool value = convertMatlabToBool( prhs[i+1] );
  294. conf.sB("GPHIKClassifier", variable, value);
  295. }
  296. else
  297. {
  298. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. \'true\', \'false\', or logical expected.";
  299. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  300. }
  301. }
  302. /////////////////////////////////////////
  303. // READ STANDARD INT VARIABLES
  304. /////////////////////////////////////////
  305. if ( (variable == "nrOfEigenvaluesToConsiderForVarApprox")
  306. )
  307. {
  308. if ( mxIsDouble( prhs[i+1] ) )
  309. {
  310. double value = convertMatlabToDouble(prhs[i+1]);
  311. conf.sI("GPHIKClassifier", variable, (int) value);
  312. }
  313. else if ( mxIsInt32( prhs[i+1] ) )
  314. {
  315. int value = convertMatlabToInt32(prhs[i+1]);
  316. conf.sI("GPHIKClassifier", variable, value);
  317. }
  318. else
  319. {
  320. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  321. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  322. }
  323. }
  324. /////////////////////////////////////////
  325. // READ STRICT POSITIVE INT VARIABLES
  326. /////////////////////////////////////////
  327. if ( (variable == "num_bins") || (variable == "ils_max_iterations")
  328. )
  329. {
  330. if ( mxIsDouble( prhs[i+1] ) )
  331. {
  332. double value = convertMatlabToDouble(prhs[i+1]);
  333. if( value < 1 )
  334. {
  335. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  336. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  337. }
  338. conf.sI("GPHIKClassifier", variable, (int) value);
  339. }
  340. else if ( mxIsInt32( prhs[i+1] ) )
  341. {
  342. int value = convertMatlabToInt32(prhs[i+1]);
  343. if( value < 1 )
  344. {
  345. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  346. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  347. }
  348. conf.sI("GPHIKClassifier", variable, value);
  349. }
  350. else
  351. {
  352. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Int32 or Double expected.";
  353. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  354. }
  355. }
  356. /////////////////////////////////////////
  357. // READ POSITIVE DOUBLE VARIABLES
  358. /////////////////////////////////////////
  359. if ( (variable == "ils_min_delta") || (variable == "ils_min_residual") ||
  360. (variable == "noise")
  361. )
  362. {
  363. if ( mxIsDouble( prhs[i+1] ) )
  364. {
  365. double value = convertMatlabToDouble(prhs[i+1]);
  366. if( value < 0.0 )
  367. {
  368. std::string errorMsg = "Expected parameter value larger than 0 for \'" + variable + "\'.";
  369. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  370. }
  371. conf.sD("GPHIKClassifier", variable, value);
  372. }
  373. else
  374. {
  375. std::string errorMsg = "Unexpected parameter value for \'" + variable + "\'. Double expected.";
  376. mexErrMsgIdAndTxt( "mexnice:error", errorMsg.c_str() );
  377. }
  378. }
  379. /////////////////////////////////////////
  380. // READ REMAINING SPECIFIC VARIABLES
  381. /////////////////////////////////////////
  382. if(variable == "ils_method")
  383. {
  384. string value = convertMatlabToString(prhs[i+1]);
  385. if(value != "CG" && value != "CGL" && value != "SYMMLQ" && value != "MINRES")
  386. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'ils_method\'. \'CG\', \'CGL\', \'SYMMLQ\' or \'MINRES\' expected.");
  387. conf.sS("GPHIKClassifier", variable, value);
  388. }
  389. if(variable == "optimization_method")
  390. {
  391. string value = convertMatlabToString(prhs[i+1]);
  392. if(value != "greedy" && value != "downhillsimplex" && value != "none")
  393. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'optimization_method\'. \'greedy\', \'downhillsimplex\' or \'none\' expected.");
  394. conf.sS("GPHIKClassifier", variable, value);
  395. }
  396. if(variable == "transform")
  397. {
  398. string value = convertMatlabToString( prhs[i+1] );
  399. if(value != "absexp" && value != "exp" && value != "MKL" && value != "WeightedDim")
  400. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'transform\'. \'absexp\', \'exp\' , \'MKL\' or \'WeightedDim\' expected.");
  401. conf.sS("GPHIKClassifier", variable, value);
  402. }
  403. if(variable == "varianceApproximation")
  404. {
  405. string value = convertMatlabToString(prhs[i+1]);
  406. if(value != "approximate_fine" && value != "approximate_rough" && value != "exact" && value != "none")
  407. mexErrMsgIdAndTxt("mexnice:error","Unexpected parameter value for \'varianceApproximation\'. \'approximate_fine\', \'approximate_rough\', \'none\' or \'exact\' expected.");
  408. conf.sS("GPHIKClassifier", variable, value);
  409. }
  410. }
  411. return conf;
  412. }
  413. // MAIN MATLAB FUNCTION
  414. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  415. {
  416. // get the command string specifying what to do
  417. if (nrhs < 1)
  418. mexErrMsgTxt("No commands and options passed... Aborting!");
  419. if( !mxIsChar( prhs[0] ) )
  420. mexErrMsgTxt("First argument needs to be the command, ie.e, the class method to call... Aborting!");
  421. std::string cmd = convertMatlabToString( prhs[0] );
  422. // create object
  423. if ( !strcmp("new", cmd.c_str() ) )
  424. {
  425. // check output variable
  426. if (nlhs != 1)
  427. mexErrMsgTxt("New: One output expected.");
  428. // read config settings
  429. NICE::Config conf = parseParameters(prhs+1,nrhs-1);
  430. // create class instance
  431. NICE::GPHIKClassifier * classifier = new NICE::GPHIKClassifier ( &conf, "GPHIKClassifier" /*sectionName in config*/ );
  432. // handle to the C++ instance
  433. plhs[0] = convertPtr2Mat<NICE::GPHIKClassifier>( classifier );
  434. return;
  435. }
  436. // in all other cases, there should be a second input,
  437. // which the be the class instance handle
  438. if (nrhs < 2)
  439. mexErrMsgTxt("Second input should be a class instance handle.");
  440. // delete object
  441. if ( !strcmp("delete", cmd.c_str() ) )
  442. {
  443. // Destroy the C++ object
  444. destroyObject<NICE::GPHIKClassifier>(prhs[1]);
  445. return;
  446. }
  447. // get the class instance pointer from the second input
  448. // every following function needs the classifier object
  449. NICE::GPHIKClassifier * classifier = convertMat2Ptr<NICE::GPHIKClassifier>(prhs[1]);
  450. ////////////////////////////////////////
  451. // Check which class method to call //
  452. ////////////////////////////////////////
  453. // standard train - assumes initialized object
  454. if (!strcmp("train", cmd.c_str() ))
  455. {
  456. // Check parameters
  457. if (nlhs < 0 || nrhs < 4)
  458. {
  459. mexErrMsgTxt("Train: Unexpected arguments.");
  460. }
  461. //------------- read the data --------------
  462. std::vector< const NICE::SparseVector *> examplesTrain;
  463. NICE::Vector yMultiTrain;
  464. if ( mxIsSparse( prhs[2] ) )
  465. {
  466. examplesTrain = convertSparseMatrixToNice( prhs[2] );
  467. }
  468. else
  469. {
  470. NICE::Matrix dataTrain;
  471. dataTrain = convertDoubleMatrixToNice(prhs[2]);
  472. //----------------- convert data to sparse data structures ---------
  473. examplesTrain.resize( dataTrain.rows() );
  474. std::vector< const NICE::SparseVector *>::iterator exTrainIt = examplesTrain.begin();
  475. for (int i = 0; i < (int)dataTrain.rows(); i++, exTrainIt++)
  476. {
  477. *exTrainIt = new NICE::SparseVector( dataTrain.getRow(i) );
  478. }
  479. }
  480. yMultiTrain = convertDoubleVectorToNice(prhs[3]);
  481. //----------------- train our classifier -------------
  482. classifier->train ( examplesTrain , yMultiTrain );
  483. //----------------- clean up -------------
  484. for(int i=0;i<examplesTrain.size();i++)
  485. delete examplesTrain[i];
  486. return;
  487. }
  488. // Classify
  489. if ( !strcmp("classify", cmd.c_str() ) )
  490. {
  491. // Check parameters
  492. if ( (nlhs < 0) || (nrhs < 2) )
  493. {
  494. mexErrMsgTxt("Test: Unexpected arguments.");
  495. }
  496. //------------- read the data --------------
  497. int result;
  498. NICE::SparseVector scores;
  499. double uncertainty;
  500. if ( mxIsSparse( prhs[2] ) )
  501. {
  502. NICE::SparseVector * example;
  503. example = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  504. classifier->classify ( example, result, scores, uncertainty );
  505. //----------------- clean up -------------
  506. delete example;
  507. }
  508. else
  509. {
  510. NICE::Vector * example;
  511. example = new NICE::Vector ( convertDoubleVectorToNice(prhs[2]) );
  512. classifier->classify ( example, result, scores, uncertainty );
  513. //----------------- clean up -------------
  514. delete example;
  515. }
  516. // output
  517. plhs[0] = mxCreateDoubleScalar( result );
  518. if(nlhs >= 2)
  519. {
  520. plhs[1] = convertSparseVectorFromNice( scores, true /*b_adaptIndex*/);
  521. }
  522. if(nlhs >= 3)
  523. {
  524. plhs[2] = mxCreateDoubleScalar( uncertainty );
  525. }
  526. return;
  527. }
  528. // Uncertainty prediction
  529. if ( !strcmp("uncertainty", cmd.c_str() ) )
  530. {
  531. // Check parameters
  532. if ( (nlhs < 0) || (nrhs < 2) )
  533. {
  534. mexErrMsgTxt("Test: Unexpected arguments.");
  535. }
  536. double uncertainty;
  537. //------------- read the data --------------
  538. if ( mxIsSparse( prhs[2] ) )
  539. {
  540. NICE::SparseVector * example;
  541. example = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  542. classifier->predictUncertainty( example, uncertainty );
  543. //----------------- clean up -------------
  544. delete example;
  545. }
  546. else
  547. {
  548. NICE::Vector * example;
  549. example = new NICE::Vector ( convertDoubleVectorToNice(prhs[2]) );
  550. classifier->predictUncertainty( example, uncertainty );
  551. //----------------- clean up -------------
  552. delete example;
  553. }
  554. // output
  555. plhs[0] = mxCreateDoubleScalar( uncertainty );
  556. return;
  557. }
  558. // Test - evaluate classifier on whole test set
  559. if ( !strcmp("test", cmd.c_str() ) )
  560. {
  561. // Check parameters
  562. if (nlhs < 0 || nrhs < 4)
  563. mexErrMsgTxt("Test: Unexpected arguments.");
  564. //------------- read the data --------------
  565. bool dataIsSparse ( mxIsSparse( prhs[2] ) );
  566. std::vector< const NICE::SparseVector *> dataTest_sparse;
  567. NICE::Matrix dataTest_dense;
  568. if ( dataIsSparse )
  569. {
  570. dataTest_sparse = convertSparseMatrixToNice( prhs[2] );
  571. }
  572. else
  573. {
  574. dataTest_dense = convertDoubleMatrixToNice(prhs[2]);
  575. }
  576. NICE::Vector yMultiTest;
  577. yMultiTest = convertDoubleVectorToNice(prhs[3]);
  578. // ------------------------------------------
  579. // ------------- PREPARATION --------------
  580. // ------------------------------------------
  581. // determine classes known during training and corresponding mapping
  582. // thereby allow for non-continous class labels
  583. std::set<int> classesKnownTraining = classifier->getKnownClassNumbers();
  584. int noClassesKnownTraining ( classesKnownTraining.size() );
  585. std::map<int,int> mapClNoToIdxTrain;
  586. std::set<int>::const_iterator clTrIt = classesKnownTraining.begin();
  587. for ( int i=0; i < noClassesKnownTraining; i++, clTrIt++ )
  588. mapClNoToIdxTrain.insert ( std::pair<int,int> ( *clTrIt, i ) );
  589. // determine classes known during testing and corresponding mapping
  590. // thereby allow for non-continous class labels
  591. std::set<int> classesKnownTest;
  592. classesKnownTest.clear();
  593. // determine which classes we have in our label vector
  594. // -> MATLAB: myClasses = unique(y);
  595. for ( NICE::Vector::const_iterator it = yMultiTest.begin(); it != yMultiTest.end(); it++ )
  596. {
  597. if ( classesKnownTest.find ( *it ) == classesKnownTest.end() )
  598. {
  599. classesKnownTest.insert ( *it );
  600. }
  601. }
  602. int noClassesKnownTest ( classesKnownTest.size() );
  603. std::map<int,int> mapClNoToIdxTest;
  604. std::set<int>::const_iterator clTestIt = classesKnownTest.begin();
  605. for ( int i=0; i < noClassesKnownTest; i++, clTestIt++ )
  606. mapClNoToIdxTest.insert ( std::pair<int,int> ( *clTestIt, i ) );
  607. int i_numTestSamples;
  608. if ( dataIsSparse )
  609. i_numTestSamples = dataTest_sparse.size();
  610. else
  611. i_numTestSamples = (int) dataTest_dense.rows();
  612. NICE::Matrix confusionMatrix( noClassesKnownTraining, noClassesKnownTest, 0.0);
  613. NICE::Matrix scores( i_numTestSamples, noClassesKnownTraining, 0.0);
  614. // ------------------------------------------
  615. // ------------- CLASSIFICATION --------------
  616. // ------------------------------------------
  617. NICE::Timer t;
  618. double testTime (0.0);
  619. for (int i = 0; i < i_numTestSamples; i++)
  620. {
  621. //----------------- convert data to sparse data structures ---------
  622. int result;
  623. NICE::SparseVector exampleScoresSparse;
  624. if ( dataIsSparse )
  625. {
  626. // and classify
  627. t.start();
  628. classifier->classify( dataTest_sparse[ i ], result, exampleScoresSparse );
  629. t.stop();
  630. testTime += t.getLast();
  631. }
  632. else
  633. {
  634. NICE::Vector example ( dataTest_dense.getRow(i) );
  635. // and classify
  636. t.start();
  637. classifier->classify( &example, result, exampleScoresSparse );
  638. t.stop();
  639. testTime += t.getLast();
  640. }
  641. confusionMatrix( mapClNoToIdxTrain.find(result)->second, mapClNoToIdxTest.find(yMultiTest[i])->second ) += 1.0;
  642. int scoreCnt ( 0 );
  643. for ( NICE::SparseVector::const_iterator scoreIt = exampleScoresSparse.begin(); scoreIt != exampleScoresSparse.end(); scoreIt++, scoreCnt++ )
  644. {
  645. scores(i,scoreCnt) = scoreIt->second;
  646. }
  647. }
  648. std::cerr << "Time for testing: " << testTime << std::endl;
  649. // clean up
  650. if ( dataIsSparse )
  651. {
  652. for ( std::vector<const NICE::SparseVector *>::iterator it = dataTest_sparse.begin(); it != dataTest_sparse.end(); it++)
  653. delete *it;
  654. }
  655. confusionMatrix.normalizeColumnsL1();
  656. double recRate = confusionMatrix.trace()/confusionMatrix.cols();
  657. plhs[0] = mxCreateDoubleScalar( recRate );
  658. if(nlhs >= 2)
  659. plhs[1] = convertMatrixFromNice(confusionMatrix);
  660. if(nlhs >= 3)
  661. plhs[2] = convertMatrixFromNice(scores);
  662. return;
  663. }
  664. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  665. // interface specific methods for incremental extensions
  666. ///////////////////// INTERFACE ONLINE LEARNABLE /////////////////////
  667. // addExample
  668. if ( !strcmp("addExample", cmd.c_str() ) )
  669. {
  670. // Check parameters
  671. if ( (nlhs < 0) || (nrhs < 4) )
  672. {
  673. mexErrMsgTxt("Test: Unexpected arguments.");
  674. }
  675. //------------- read the data --------------
  676. NICE::SparseVector * newExample;
  677. double newLabel;
  678. if ( mxIsSparse( prhs[2] ) )
  679. {
  680. newExample = new NICE::SparseVector ( convertSparseVectorToNice( prhs[2] ) );
  681. }
  682. else
  683. {
  684. NICE::Vector * example;
  685. example = new NICE::Vector ( convertDoubleVectorToNice(prhs[2]) );
  686. newExample = new NICE::SparseVector ( *example );
  687. //----------------- clean up -------------
  688. delete example;
  689. }
  690. newLabel = convertMatlabToDouble( prhs[3] );
  691. // setting performOptimizationAfterIncrement is optional
  692. if ( nrhs > 4 )
  693. {
  694. bool performOptimizationAfterIncrement;
  695. performOptimizationAfterIncrement = convertMatlabToBool( prhs[4] );
  696. classifier->addExample ( newExample, newLabel, performOptimizationAfterIncrement );
  697. }
  698. else
  699. {
  700. classifier->addExample ( newExample, newLabel );
  701. }
  702. //----------------- clean up -------------
  703. delete newExample;
  704. return;
  705. }
  706. // addMultipleExamples
  707. if ( !strcmp("addMultipleExamples", cmd.c_str() ) )
  708. {
  709. // Check parameters
  710. if ( (nlhs < 0) || (nrhs < 4) )
  711. {
  712. mexErrMsgTxt("Test: Unexpected arguments.");
  713. }
  714. //------------- read the data --------------
  715. std::vector< const NICE::SparseVector *> newExamples;
  716. NICE::Vector newLabels;
  717. if ( mxIsSparse( prhs[2] ) )
  718. {
  719. newExamples = convertSparseMatrixToNice( prhs[2] );
  720. }
  721. else
  722. {
  723. NICE::Matrix newData;
  724. newData = convertDoubleMatrixToNice(prhs[2]);
  725. //----------------- convert data to sparse data structures ---------
  726. newExamples.resize( newData.rows() );
  727. std::vector< const NICE::SparseVector *>::iterator exTrainIt = newExamples.begin();
  728. for (int i = 0; i < (int)newData.rows(); i++, exTrainIt++)
  729. {
  730. *exTrainIt = new NICE::SparseVector( newData.getRow(i) );
  731. }
  732. }
  733. newLabels = convertDoubleVectorToNice(prhs[3]);
  734. // setting performOptimizationAfterIncrement is optional
  735. if ( nrhs > 4 )
  736. {
  737. bool performOptimizationAfterIncrement;
  738. performOptimizationAfterIncrement = convertMatlabToBool( prhs[4] );
  739. classifier->addMultipleExamples ( newExamples, newLabels, performOptimizationAfterIncrement );
  740. }
  741. else
  742. {
  743. classifier->addMultipleExamples ( newExamples, newLabels );
  744. }
  745. //----------------- clean up -------------
  746. for ( std::vector< const NICE::SparseVector *>::iterator exIt = newExamples.begin();
  747. exIt != newExamples.end(); exIt++
  748. )
  749. {
  750. delete *exIt;
  751. }
  752. return;
  753. }
  754. ///////////////////// INTERFACE PERSISTENT /////////////////////
  755. // interface specific methods for store and restore
  756. ///////////////////// INTERFACE PERSISTENT /////////////////////
  757. // store the classifier to an external file
  758. if ( !strcmp("store", cmd.c_str() ) || !strcmp("save", cmd.c_str() ) )
  759. {
  760. // Check parameters
  761. if ( nrhs < 3 )
  762. mexErrMsgTxt("store: no destination given.");
  763. std::string s_destination = convertMatlabToString( prhs[2] );
  764. std::filebuf fb;
  765. fb.open ( s_destination.c_str(), ios::out );
  766. std::ostream os(&fb);
  767. //
  768. classifier->store( os );
  769. //
  770. fb.close();
  771. return;
  772. }
  773. // load classifier from external file
  774. if ( !strcmp("restore", cmd.c_str() ) || !strcmp("load", cmd.c_str() ) )
  775. {
  776. // Check parameters
  777. if ( nrhs < 3 )
  778. mexErrMsgTxt("restore: no destination given.");
  779. std::string s_destination = convertMatlabToString( prhs[2] );
  780. std::cerr << " aim at restoring the classifier from " << s_destination << std::endl;
  781. std::filebuf fbIn;
  782. fbIn.open ( s_destination.c_str(), ios::in );
  783. std::istream is (&fbIn);
  784. //
  785. classifier->restore( is );
  786. //
  787. fbIn.close();
  788. return;
  789. }
  790. // Got here, so command not recognized
  791. std::string errorMsg (cmd.c_str() );
  792. errorMsg += " -- command not recognized.";
  793. mexErrMsgTxt( errorMsg.c_str() );
  794. }