GMM.cpp
#ifdef NICE_USELIB_OPENMP
#include <omp.h>
#endif

#include <stdio.h>

#include "GMM.h"
#include <core/vector/Algorithms.h>
#include "vislearning/math/cluster/KMeans.h"

// #define DEBUGGMM

using namespace OBJREC;
using namespace NICE;
using namespace std;

///////////////////// ///////////////////// /////////////////////
//                 CONSTRUCTORS / DESTRUCTORS
///////////////////// ///////////////////// /////////////////////
GMM::GMM() : ClusterAlgorithm()
{
  this->i_numOfGaussians = 3;
  this->dim = -1;
  this->mu.clear();
  this->sparse_sigma.clear();
  this->priors.clear();
  this->sparse_inv_sigma.clear();
  this->log_det_sigma.clear();
  this->mu2.clear();
  this->sparse_sigma2.clear();
  this->priors2.clear();
  this->b_compareTo2ndGMM = false;
  this->maxiter = 200;
  this->featsperclass = 0;
  this->cdimval = -1; //TODO
  this->tau = 10.0;
  this->pyramid = false;

  srand ( time ( NULL ) );
}

GMM::GMM ( int _numOfGaussians ) : i_numOfGaussians ( _numOfGaussians )
{
  this->dim = -1;
  this->mu.clear();
  this->sparse_sigma.clear();
  this->priors.clear();
  this->sparse_inv_sigma.clear();
  this->log_det_sigma.clear();
  this->mu2.clear();
  this->sparse_sigma2.clear();
  this->priors2.clear();
  this->b_compareTo2ndGMM = false;
  this->maxiter = 200;
  this->featsperclass = 0;
  this->tau = 0.0;
  this->pyramid = false;

  srand ( time ( NULL ) );
}

GMM::GMM ( const Config * _conf, int _numOfGaussians ) : i_numOfGaussians ( _numOfGaussians )
{
  this->initFromConfig( _conf );
}

GMM::GMM ( const Config * _conf, const std::string& _confSection )
{
  this->initFromConfig( _conf, _confSection );
}

GMM::~GMM()
{
}
void GMM::initFromConfig( const NICE::Config* _conf, const std::string& _confSection )
{
  if ( this->i_numOfGaussians < 2 )
    this->i_numOfGaussians = _conf->gI ( _confSection, "i_numOfGaussians", 2 );

  this->dim = -1;
  this->mu.clear();
  this->sparse_sigma.clear();
  this->priors.clear();
  this->sparse_inv_sigma.clear();
  this->log_det_sigma.clear();
  this->mu2.clear();
  this->sparse_sigma2.clear();
  this->priors2.clear();
  this->b_compareTo2ndGMM = false;

  this->maxiter = _conf->gI ( _confSection, "maxiter", 200 );
  this->featsperclass = _conf->gI ( _confSection, "featsperclass", 0 );
  this->tau = _conf->gD ( _confSection, "tau", 100.0 );
  this->pyramid = _conf->gB ( _confSection, "pyramid", false );

  srand ( time ( NULL ) );
}
///////////////////// ///////////////////// /////////////////////
//                      CLUSTERING STUFF
///////////////////// ///////////////////// /////////////////////
void GMM::computeMixture ( Examples examples )
{
  int fsize = ( int ) examples.size();
  assert ( fsize >= i_numOfGaussians );

  dim = examples[0].second.vec->size();

  int samples = fsize;
  if ( featsperclass > 0 )
  {
    samples = featsperclass * i_numOfGaussians;
    samples = std::min ( samples, fsize );
  }

  // copy a random subset of the examples into a VVector
  VVector dataset;
  cout << "reduced training data for GMM from " << fsize << " features to " << samples << " features." << endl;
  for ( int i = 0; i < samples; i++ )
  {
    int k = rand() % fsize;
    NICE::Vector *vec = examples[k].second.vec;
    dataset.push_back ( *vec );
  }

  computeMixture ( dataset );
}
void GMM::computeMixture ( const VVector &DataSet )
{
  // Learn the GMM model
  assert ( DataSet.size() >= ( uint ) i_numOfGaussians );
  //initEMkMeans(DataSet); // initialize the model
  srand ( time ( NULL ) );

  bool poweroftwo = false;
  int power = 1;
  while ( power <= i_numOfGaussians )
  {
    if ( i_numOfGaussians == power )
      poweroftwo = true;
    power *= 2;
  }

  if ( poweroftwo && pyramid )
  {
    // pyramid strategy: run EM with 1, 2, 4, ... components, duplicating each
    // component and jittering its mean before the number of Gaussians is doubled
    initEM ( DataSet ); // initialize the model
    int g = 1;

    while ( g <= i_numOfGaussians )
    {
      cout << "g = " << g << endl;
      doEM ( DataSet, g );

      if ( 2*g <= i_numOfGaussians )
      {
        for ( int i = g; i < g*2; i++ )
        {
          mu[i] = mu[i-g];
          sparse_sigma[i] = sparse_sigma[i-g];
          sparse_inv_sigma[i] = sparse_inv_sigma[i-g];
          log_det_sigma[i] = log_det_sigma[i-g];
          priors[i] = priors[i-g];

          double interval = 1.0;
          for ( int k = 0; k < ( int ) mu[i].size(); k++ )
          {
            interval = mu[i][k];
            interval = std::max ( interval, 0.1 );
            double r = ( interval * ( ( double ) rand() / ( double ) RAND_MAX ) ) - interval / 2.0;
            mu[i][k] += r;
          }
        }
      }
      g *= 2;
    }
  }
  else
  {
    initEMkMeans ( DataSet ); // initialize the model
    doEM ( DataSet, i_numOfGaussians ); // performs EM
  }
}
inline double diagDeterminant ( const NICE::Vector &sparse_mat )
{
  double det = 1.0;
  for ( int i = 0; i < ( int ) sparse_mat.size(); i++ )
  {
    det *= sparse_mat[i];
  }
  return det;
}

inline double logdiagDeterminant ( const NICE::Vector &sparse_mat )
{
  double det = 0.0;
  for ( int i = 0; i < ( int ) sparse_mat.size(); i++ )
  {
    det += log ( sparse_mat[i] );
  }
  return det;
}

inline NICE::Vector diagInverse ( const NICE::Vector &sparse_mat )
{
  NICE::Vector inv ( sparse_mat.size() );
  for ( int i = 0; i < ( int ) sparse_mat.size(); i++ )
  {
    inv[i] = 1.0 / sparse_mat[i];
  }
  return inv;
}
void GMM::initEMkMeans ( const VVector &DataSet )
{
  /* init the GMM with k-means */
  OBJREC::KMeans k ( i_numOfGaussians );
  NICE::VVector means;
  std::vector<double> weights;
  std::vector<int> assignment;
  k.cluster ( DataSet, means, weights, assignment );

  int nData = DataSet.size();
  this->dim = DataSet[0].size();
  cdimval = dim * log ( 2 * 3.14159 );
  std::vector<int> pop ( i_numOfGaussians, 0 );
  priors.resize ( i_numOfGaussians );
  mu = VVector ( i_numOfGaussians, dim );
  log_det_sigma.clear();
  vector<int> allk;

  NICE::Vector globmean ( dim );
  globmean.set ( 0.0 );

  for ( int n = 0; n < ( int ) DataSet.size(); n++ ) /* compute the global mean */
  {
    globmean = globmean + DataSet[n];
  }
  globmean *= ( 1.0 / nData );

  NICE::Vector sparse_globsigma;
  sparse_globsigma.resize ( dim );
  sparse_globsigma.set ( 0.0 );

  for ( int i = 0; i < ( int ) DataSet.size(); i++ ) // global covariance
  {
    // compute only the diagonal elements
    for ( int d = 0; d < dim; d++ )
    {
      double diff = ( DataSet[i][d] - globmean[d] );
      sparse_globsigma[d] += diff * diff;
    }
  }
  sparse_globsigma *= ( 1.0 / DataSet.size() );

  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    NICE::Vector _inv_sigma = diagInverse ( sparse_globsigma );
    sparse_sigma.push_back ( sparse_globsigma );
    sparse_inv_sigma.push_back ( _inv_sigma );
    log_det_sigma.push_back ( logdiagDeterminant ( sparse_globsigma ) );
    mu[i] = means[i];
    //priors[i] = 1.0 / (double)i_numOfGaussians; // set equiprobable states
    priors[i] = weights[i]; // set k-means weights
  }
}
void GMM::initEM ( const VVector &DataSet )
{
  /* init the GaussianMixture by using randomized mean vectors */
  int nData = DataSet.size();
  this->dim = DataSet[0].size();
  cdimval = dim * log ( 2 * 3.14159 );
  vector<int> pop ( i_numOfGaussians, 0 );
  priors.resize ( i_numOfGaussians );
  mu = VVector ( i_numOfGaussians, dim );
  log_det_sigma.clear();
  vector<int> allk;

  NICE::Vector globmean ( dim );
  globmean.set ( 0.0 );

  for ( int n = 0; n < ( int ) DataSet.size(); n++ ) /* compute the global mean */
  {
    globmean = globmean + DataSet[n];
  }
  globmean *= ( 1.0 / nData );

  NICE::Vector sparse_globsigma;
  sparse_globsigma.resize ( dim );
  sparse_globsigma.set ( 0.0 );

  for ( int i = 0; i < ( int ) DataSet.size(); i++ ) // global covariance
  {
    // compute only the diagonal elements
    for ( int d = 0; d < dim; d++ )
    {
      double diff = ( DataSet[i][d] - globmean[d] );
      sparse_globsigma[d] += diff * diff;
    }
  }
  sparse_globsigma *= ( 1.0 / DataSet.size() );

  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    priors[i] = 1.0 / ( double ) i_numOfGaussians; // set equiprobable states
    NICE::Vector _inv_sigma = diagInverse ( sparse_globsigma );
    sparse_sigma.push_back ( sparse_globsigma );
    sparse_inv_sigma.push_back ( _inv_sigma );
    log_det_sigma.push_back ( logdiagDeterminant ( sparse_globsigma ) );

    // pick a data point that has not been chosen as a mean yet
    bool newv = false;
    int k;
    while ( !newv )
    {
      newv = true;
      k = rand() % nData;
      for ( int nk = 0; nk < ( int ) allk.size(); nk++ )
        if ( allk[nk] == k )
        {
          newv = false;
        }
      if ( newv )
        allk.push_back ( k );
    }
    mu[i] = DataSet[k];
  }
}
inline void sumRow ( const NICE::Matrix &mat, NICE::Vector &res )
{
  int row = mat.rows();
  int column = mat.cols();
  res = NICE::Vector ( column );
  res.set ( 1e-5f );
  //#pragma omp parallel for
  for ( int i = 0; i < column; i++ ) {
    for ( int j = 0; j < row; j++ ) {
      res[i] += mat ( j, i );
    }
  }
}
double GMM::logpdfState ( const NICE::Vector &Vin, int state )
{
  /* get the log probability density for a given state and a given vector */
  double p;
  NICE::Vector dif;
  dif = Vin;
  dif -= mu[state];

  p = 0.0;
  for ( int i = 0; i < ( int ) dif.size(); i++ )
  {
    p += dif[i] * dif[i] * sparse_inv_sigma[state][i];
  }
  p = -0.5 * ( p + cdimval + log_det_sigma[state] );
  return p;
}
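
// For reference, the expression above is the diagonal-covariance Gaussian log-density
// (with cdimval = dim * log(2*pi), precomputed in initEM / initEMkMeans / loadData):
//
//   log N(x; mu_k, Sigma_k) = -0.5 * ( (x - mu_k)^T Sigma_k^{-1} (x - mu_k)
//                                      + dim * log(2*pi) + log |Sigma_k| )
//
// Since Sigma_k is stored as a diagonal (sparse_sigma), the quadratic form reduces to a
// sum over dimensions and log|Sigma_k| to the sum of the log variances (log_det_sigma).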
int GMM::doEM ( const VVector &DataSet, int nbgaussians )
{
  /* perform Expectation/Maximization on the given Dataset :
     Matrix DataSet(nSamples,Dimensions).
     The GaussianMixture object must be initialised before
     (see initEM or initEMkMeans methods) */
#ifdef DEBUG
  cerr << "GMM::start EM" << endl;
#endif

  int nData = DataSet.size();
  int iter = 0;
  double log_lik;
  double log_lik_threshold = 1e-6f;
  double log_lik_old = -1e10f;

  NICE::Matrix unity ( dim, dim, 0.0 );
  for ( int k = 0; k < dim; k++ )
    unity ( k, k ) = 1.0;

  // EM loop
  while ( true )
  {
#ifdef DEBUGGMM
    cerr << "GMM::EM: iteration: " << iter << endl;
#endif

    vector<double> sum_p;
    sum_p.resize ( nData );
    for ( int i = 0; i < nData; i++ )
    {
      sum_p[i] = 0.0;
    }

    NICE::Matrix pxi ( nData, i_numOfGaussians );
    pxi.set ( 0.0 );
    NICE::Matrix pix ( nData, i_numOfGaussians );
    pix.set ( 0.0 );
    NICE::Vector E;

    iter++;
    if ( iter > maxiter ) {
      cerr << "EM stops here. Max number of iterations (" << maxiter << ") has been reached." << endl;
      return iter;
    }

    double sum_log = 0.0;

    // E-step: compute the expectation
    double maxp = -numeric_limits<double>::max();
    vector<double> maxptmp ( nData, -numeric_limits<double>::max() );
#pragma omp parallel for
    for ( int i = 0; i < nData; i++ )
    {
      for ( int j = 0; j < nbgaussians; j++ )
      {
        double p = logpdfState ( DataSet[i], j ); // log(p(x|j))
        maxptmp[i] = std::max ( maxptmp[i], p );
        pxi ( i, j ) = p;
      }
    }

    for ( int i = 0; i < nData; i++ )
    {
      maxp = std::max ( maxp, maxptmp[i] );
    }

#pragma omp parallel for
    for ( int i = 0; i < nData; i++ )
    {
      sum_p[i] = 0.0;
      for ( int j = 0; j < nbgaussians; j++ )
      {
        double p = pxi ( i, j ) - maxp; // log(p(x|j)), shifted for numerical stability
        p = exp ( p );                  // p(x|j)
        if ( p < 1e-40 )
          p = 1e-40;
        pxi ( i, j ) = p;
        sum_p[i] += p * priors[j];
      }
    }

    for ( int i = 0; i < nData; i++ )
    {
      sum_log += log ( sum_p[i] );
    }

#pragma omp parallel for
    for ( int j = 0; j < nbgaussians; j++ )
    {
      for ( int i = 0; i < nData; i++ )
      {
        pix ( i, j ) = ( pxi ( i, j ) * priors[j] ) / sum_p[i]; // responsibilities p(j|x)
      }
    }

    // here we compute the log likelihood
    log_lik = sum_log / nData;
#ifdef DEBUGGMM
    cout << "diff: " << fabs ( ( log_lik / log_lik_old ) - 1 ) << " thresh: " << log_lik_threshold << " sum: " << sum_log << endl;
    //cout << "logold: " << log_lik_old << " lognew: " << log_lik << endl;
#endif

    if ( fabs ( ( log_lik / log_lik_old ) - 1 ) < log_lik_threshold )
    {
      // if the log likelihood has not moved enough, the algorithm has converged: exit the loop
      return iter;
    }
    log_lik_old = log_lik;

    // M-step: update priors, means and (diagonal) covariances
    sumRow ( pix, E );
#pragma omp parallel for
    for ( int j = 0; j < nbgaussians; j++ )
    {
      priors[j] = ( E[j] + tau ) / ( nData + tau * nbgaussians ); // new priors

      NICE::Vector tmu ( dim );
      tmu.set ( 0.0 );
      NICE::Vector sparse_tmsigma ( dim );
      sparse_tmsigma.set ( 0.0 );

      for ( int i = 0; i < nData; i++ ) // means update loop
      {
        tmu = tmu + ( DataSet[i] * pix ( i, j ) );
      }

      NICE::Vector tmu2 = mu[j];
      mu[j] = tmu + tau * tmu2;
      mu[j] = mu[j] * ( 1.0 / ( E[j] + tau ) );

      for ( int i = 0; i < nData; i++ ) // covariances update
      {
        // compute only the diagonal elements
        for ( int d = 0; d < dim; d++ )
        {
          sparse_tmsigma[d] += DataSet[i][d] * DataSet[i][d] * pix ( i, j );
        }
      }

      NICE::Vector sparse_tmsigma2 ( dim );
      for ( int d = 0; d < dim; d++ )
      {
        sparse_tmsigma[d] += 1e-5f;
        sparse_tmsigma2[d] = sparse_sigma[j][d] + tmu2[d] * tmu2[d];
        sparse_sigma[j][d] = ( sparse_tmsigma[d] + tau * sparse_tmsigma2[d] ) / ( E[j] + tau ) - ( mu[j][d] * mu[j][d] );
      }

      sparse_inv_sigma[j] = diagInverse ( sparse_sigma[j] );
      log_det_sigma[j] = logdiagDeterminant ( sparse_sigma[j] );
    }

    if ( b_compareTo2ndGMM )
    {
      double dist = this->compareTo2ndGMM();
      std::cout << "dist for iteration " << iter << std::endl;
      std::cout << "d: " << dist << std::endl;
    }
  }

#ifdef DEBUG
  cerr << "GMM::finished EM after " << iter << " iterations" << endl;
#endif
  return iter;
}
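
// Summary of the update rules implemented in doEM above (tau acts as a smoothing /
// MAP-style regularization weight towards the previous parameter estimates):
//
//   E-step:   r_ij = p(x_i|j) * prior_j / sum_k ( p(x_i|k) * prior_k )      (pix)
//             E_j  = sum_i r_ij                                             (sumRow)
//
//   M-step:   prior_j   = (E_j + tau) / (nData + tau * nbgaussians)
//             mu_j      = (sum_i r_ij * x_i + tau * mu_j_old) / (E_j + tau)
//             sigma_j,d = (sum_i r_ij * x_i,d^2 + tau * (sigma_j,d_old + mu_j,d_old^2))
//                         / (E_j + tau)  -  mu_j,d^2
//
// (plus a small 1e-5 regularizer on the accumulated second moment). With tau = 0 these
// reduce to the standard maximum-likelihood EM updates for a mixture of diagonal Gaussians.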
int GMM::getBestClass ( const NICE::Vector &v, double *bprob )
{
  int bestclass = 0;
  double maxprob = logpdfState ( v, 0 ) + log ( priors[0] ); // log( p(x|0) * prior_0 )

  for ( int i = 1; i < i_numOfGaussians; i++ )
  {
    double prob = logpdfState ( v, i ) + log ( priors[i] ); // log( p(x|i) * prior_i )
    if ( prob > maxprob )
    {
      maxprob = prob;
      bestclass = i;
    }
  }

  if ( bprob != NULL )
    *bprob = maxprob;

  return bestclass;
}
void GMM::getProbs ( const NICE::Vector &vin, SparseVector &outprobs )
{
  outprobs.clear();
  outprobs.setDim ( i_numOfGaussians );
  Vector probs;
  getProbs ( vin, probs );
  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    if ( probs[i] > 1e-7f )
      outprobs[i] = probs[i];
  }
}
void GMM::getProbs ( const NICE::Vector &vin, Vector &outprobs )
{
  Vector probs;
  probs.resize ( i_numOfGaussians );
  probs.set ( 0.0 );

  double probsum = 0.0;
  double maxp = -numeric_limits<double>::max();
  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    probs[i] = logpdfState ( vin, i ) + log ( priors[i] ); // log( p(x|i) * prior_i )
    maxp = std::max ( maxp, probs[i] );
  }

  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    probs[i] -= maxp;
    probs[i] = exp ( probs[i] );
    probsum += probs[i];
  }

  // normalize probs
#pragma omp parallel for
  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    probs[i] /= probsum;
  }
  outprobs = probs;
}
void GMM::getFisher ( const NICE::Vector &vin, SparseVector &outprobs )
{
  int size = i_numOfGaussians * 2 * dim;
  outprobs.clear();
  outprobs.setDim ( size );

  int counter = 0;
  NICE::Vector classprobs;
  classprobs.resize ( i_numOfGaussians );
  classprobs.set ( 0.0 );

  double maxp = -numeric_limits<double>::max();
  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    classprobs[i] = logpdfState ( vin, i ) + log ( priors[i] ); // log( p(x|i) * prior_i )
    maxp = std::max ( maxp, classprobs[i] );
  }

  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    double p = classprobs[i] - maxp;
    p = exp ( p );

    for ( int d = 0; d < dim; d++ )
    {
      double diff = vin[d] - mu[i][d];
      double sigma2 = sparse_sigma[i][d] * sparse_sigma[i][d];
      double sigma3 = sigma2 * sparse_sigma[i][d];
      double normmu = sqrt ( priors[i] / sigma2 );
      double normsig = sqrt ( ( 2.0 * priors[i] ) / sigma2 );

      // normalized gradients with respect to the mean and the variance
      double gradmu = ( p * ( diff / sigma2 ) ) / normmu;
      double gradsig = ( p * ( ( diff * diff ) / sigma3 - 1.0 / sparse_sigma[i][d] ) ) / normsig;

      if ( fabs ( gradmu ) > 1e-7f )
        outprobs[counter] = gradmu;
      counter++;

      if ( fabs ( gradsig ) > 1e-7f )
        outprobs[counter] = gradsig;
      counter++;
    }
  }
}
void GMM::cluster ( const VVector & features, VVector & prototypes, vector<double> & weights, vector<int> & assignment )
{
  computeMixture ( features );

  prototypes.clear();
  weights.clear();
  assignment.clear();

  for ( int i = 0; i < ( int ) features.size(); i++ )
  {
    int c = getBestClass ( features[i] );
    assignment.push_back ( c );
    weights.push_back ( priors[c] );
  }

  for ( int c = 0; c < i_numOfGaussians; c++ )
    prototypes.push_back ( mu[c] );

  cout << "tau: " << tau << endl;
}
void GMM::saveData ( const std::string filename )
{
  ofstream fout ( filename.c_str() );

  fout << i_numOfGaussians << " " << dim << endl;
  mu >> fout;
  fout << endl;

  for ( int n = 0; n < i_numOfGaussians; n++ )
  {
    fout << sparse_sigma[n] << endl;
  }

  for ( int n = 0; n < i_numOfGaussians; n++ )
  {
    fout << priors[n] << " ";
  }

  fout.close();
}
bool GMM::loadData ( const std::string filename )
{
  cerr << "read GMM data" << endl;
  ifstream fin ( filename.c_str() );

  if ( fin.fail() )
  {
    cerr << "GMM: file " << filename << " not found!" << endl;
    return false;
  }

  fin >> i_numOfGaussians;
  fin >> dim;
  cdimval = dim * log ( 2 * 3.14159 );
  mu.clear();

  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    NICE::Vector tmp;
    fin >> tmp;
    mu.push_back ( tmp );
  }

  log_det_sigma.clear();

  for ( int n = 0; n < i_numOfGaussians; n++ )
  {
    NICE::Matrix _sigma;
    NICE::Vector _sparse_sigma;

    _sparse_sigma = NICE::Vector ( dim );
    fin >> _sparse_sigma;
    sparse_sigma.push_back ( _sparse_sigma );
    sparse_inv_sigma.push_back ( diagInverse ( sparse_sigma[n] ) );
    log_det_sigma.push_back ( logdiagDeterminant ( sparse_sigma[n] ) );
  }

  for ( int n = 0; n < i_numOfGaussians; n++ )
  {
    double tmpd;
    fin >> tmpd;
    priors.push_back ( tmpd );
  }

  fin.close();
  cerr << "reading GMM data finished" << endl;
  return true;
}
void GMM::getParams ( VVector &mean, VVector &sSigma, vector<double> &p ) const
{
  mean = this->mu;
  sSigma.resize ( this->i_numOfGaussians );
  p.clear();
  for ( int i = 0; i < this->i_numOfGaussians; i++ )
  {
    sSigma[i] = this->sparse_sigma[i];
    p.push_back ( this->priors[i] );
  }
}

void GMM::setGMMtoCompareWith ( NICE::VVector mean, NICE::VVector sSigma, std::vector<double> p )
{
  this->mu2 = mean;
  this->sparse_sigma2 = sSigma;
  this->priors2 = std::vector<double> ( p );
}
double GMM::kPPK ( NICE::Vector sigma1, NICE::Vector sigma2, NICE::Vector mu1, NICE::Vector mu2, double p ) const
{
  // probability product kernel between two diagonal Gaussians
  double d = mu1.size();

  double det1 = 1.0;
  double det2 = 1.0;
  double det3 = 1.0;
  double eval = 0.0;

  for ( int i = 0; i < d; i++ )
  {
    det1 *= sigma1[i];
    det2 *= sigma2[i];
    double sigma = 1.0 / ( ( 1.0 / sigma1[i] + 1.0 / sigma2[i] ) * p );
    det3 *= sigma;
    double mu = p * ( mu1[i] * sigma1[i] + mu2[i] * sigma2[i] );
    eval += 0.5 * mu * sigma * mu - ( p * mu1[i] * mu1[i] ) / ( 2.0 * sigma1[i] ) - ( p * mu2[i] * mu2[i] ) / ( 2.0 * sigma2[i] );
  }

  return ( pow ( 2.0*M_PI, ( ( 1.0 - 2.0*p ) *d ) / 2.0 ) * pow ( det1, -p / 2.0 ) * pow ( det2, -p / 2.0 ) * pow ( det3, 0.5 ) * exp ( eval ) );
}
double GMM::compareTo2ndGMM() const
{
  // normalized kernel similarity between this GMM and the 2nd GMM
  double distkij = 0.0;
  double distkjj = 0.0;
  double distkii = 0.0;

  for ( int i = 0; i < i_numOfGaussians; i++ )
  {
    for ( int j = 0; j < i_numOfGaussians; j++ )
    {
      double kij = kPPK ( sparse_sigma[i], sparse_sigma2[j], mu[i], mu2[j], 0.5 );
      double kii = kPPK ( sparse_sigma[i], sparse_sigma[j], mu[i], mu[j], 0.5 );
      double kjj = kPPK ( sparse_sigma2[i], sparse_sigma2[j], mu2[i], mu2[j], 0.5 );
      kij = priors[i] * priors2[j] * kij;
      kii = priors[i] * priors[j] * kii;
      kjj = priors2[i] * priors2[j] * kjj;
      distkij += kij;
      distkii += kii;
      distkjj += kjj;
    }
  }

  return ( distkij / ( sqrt ( distkii ) * sqrt ( distkjj ) ) );
}

void GMM::setCompareTo2ndGMM ( const bool & _compareTo2ndGMM )
{
  this->b_compareTo2ndGMM = _compareTo2ndGMM;
}

int GMM::getNumberOfGaussians() const
{
  return this->i_numOfGaussians;
}
///////////////////// INTERFACE PERSISTENT /////////////////////
// interface specific methods for store and restore
///////////////////// INTERFACE PERSISTENT /////////////////////

void GMM::restore ( std::istream & is, int format )
{
  // delete everything we knew so far...
  this->clear();

  if ( is.good() )
  {
    std::string tmp;
    is >> tmp; // class name

    if ( ! this->isStartTag( tmp, "GMM" ) )
    {
      std::cerr << " WARNING - attempt to restore GMM, but start flag " << tmp << " does not match! Aborting... " << std::endl;
      throw;
    }

    bool b_endOfBlock ( false ) ;

    while ( !b_endOfBlock )
    {
      is >> tmp; // start of block

      if ( this->isEndTag( tmp, "GMM" ) )
      {
        b_endOfBlock = true;
        continue;
      }

      tmp = this->removeStartTag ( tmp );

      if ( tmp.compare("i_numOfGaussians") == 0 )
      {
        is >> this->i_numOfGaussians;
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("dim") == 0 )
      {
        is >> this->dim;
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("mu") == 0 )
      {
        this->mu.clear();
        this->mu.setIoUntilEndOfFile ( false );
        this->mu.restore ( is, format );
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("sparse_sigma") == 0 )
      {
        this->sparse_sigma.clear();
        this->sparse_sigma.setIoUntilEndOfFile ( false );
        this->sparse_sigma.restore ( is, format );
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("priors") == 0 )
      {
        int sizeOfPriors;
        is >> sizeOfPriors;
        this->priors.resize ( sizeOfPriors );
        for ( std::vector< double >::iterator itPriors = this->priors.begin();
              itPriors != this->priors.end();
              itPriors++
            )
        {
          is >> *itPriors;
        }
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("sparse_inv_sigma") == 0 )
      {
        this->sparse_inv_sigma.clear();
        this->sparse_inv_sigma.setIoUntilEndOfFile ( false );
        this->sparse_inv_sigma.restore ( is, format );
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("log_det_sigma") == 0 )
      {
        int sizeOfLogDetSigma;
        is >> sizeOfLogDetSigma;
        this->log_det_sigma.resize ( sizeOfLogDetSigma );
        for ( std::vector< double >::iterator itLogDetSigma = this->log_det_sigma.begin();
              itLogDetSigma != this->log_det_sigma.end();
              itLogDetSigma++
            )
        {
          is >> *itLogDetSigma;
        }
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("mu2") == 0 )
      {
        this->mu2.clear();
        this->mu2.setIoUntilEndOfFile ( false );
        this->mu2.restore ( is, format );
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("sparse_sigma2") == 0 )
      {
        this->sparse_sigma2.clear();
        this->sparse_sigma2.setIoUntilEndOfFile ( false );
        this->sparse_sigma2.restore ( is, format );
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("priors2") == 0 )
      {
        int sizeOfPriors2;
        is >> sizeOfPriors2;
        this->priors2.resize ( sizeOfPriors2 );
        for ( std::vector< double >::iterator itPriors2 = this->priors2.begin();
              itPriors2 != this->priors2.end();
              itPriors2++
            )
        {
          is >> *itPriors2;
        }
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("b_compareTo2ndGMM") == 0 )
      {
        is >> this->b_compareTo2ndGMM;
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("maxiter") == 0 )
      {
        is >> this->maxiter;
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("featsperclass") == 0 )
      {
        is >> this->featsperclass;
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("cdimval") == 0 )
      {
        is >> this->cdimval;
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("tau") == 0 )
      {
        is >> this->tau;
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else if ( tmp.compare("pyramid") == 0 )
      {
        is >> this->pyramid;
        is >> tmp; // end of block
        tmp = this->removeEndTag ( tmp );
      }
      else
      {
        std::cerr << "WARNING -- unexpected GMM object -- " << tmp << " -- for restoration... aborting" << std::endl;
        throw;
      }
    }
  }
  else
  {
    std::cerr << "GMM::restore -- InStream not initialized - restoring not possible!" << std::endl;
    throw;
  }
}
void GMM::store ( std::ostream & os, int format ) const
{
  if ( os.good() )
  {
    // show starting point
    os << this->createStartTag( "GMM" ) << std::endl;

    os << this->createStartTag( "i_numOfGaussians" ) << std::endl;
    os << this->i_numOfGaussians << std::endl;
    os << this->createEndTag( "i_numOfGaussians" ) << std::endl;

    if ( this->dim != -1 )
    {
      os << this->createStartTag( "dim" ) << std::endl;
      os << this->dim << std::endl;
      os << this->createEndTag( "dim" ) << std::endl;
    }

    if ( this->mu.size() > 0 )
    {
      os << this->createStartTag( "mu" ) << std::endl;
      this->mu.store ( os, format );
      os << this->createEndTag( "mu" ) << std::endl;
    }

    if ( this->sparse_sigma.size() > 0 )
    {
      os << this->createStartTag( "sparse_sigma" ) << std::endl;
      this->sparse_sigma.store ( os, format );
      os << this->createEndTag( "sparse_sigma" ) << std::endl;
    }

    if ( this->priors.size() > 0 )
    {
      os << this->createStartTag( "priors" ) << std::endl;
      os << this->priors.size () << std::endl;
      for ( std::vector< double >::const_iterator itPriors = this->priors.begin();
            itPriors != this->priors.end();
            itPriors++
          )
      {
        os << *itPriors << " "; // separate values so that restore can read them back
      }
      os << std::endl;
      os << this->createEndTag( "priors" ) << std::endl;
    }

    if ( this->sparse_inv_sigma.size() > 0 )
    {
      os << this->createStartTag( "sparse_inv_sigma" ) << std::endl;
      this->sparse_inv_sigma.store ( os, format );
      os << this->createEndTag( "sparse_inv_sigma" ) << std::endl;
    }

    if ( this->log_det_sigma.size() > 0 )
    {
      os << this->createStartTag( "log_det_sigma" ) << std::endl;
      os << this->log_det_sigma.size () << std::endl;
      for ( std::vector< double >::const_iterator itLogDetSigma = this->log_det_sigma.begin();
            itLogDetSigma != this->log_det_sigma.end();
            itLogDetSigma++
          )
      {
        os << *itLogDetSigma << " ";
      }
      os << std::endl;
      os << this->createEndTag( "log_det_sigma" ) << std::endl;
    }

    if ( this->mu2.size() > 0 )
    {
      os << this->createStartTag( "mu2" ) << std::endl;
      this->mu2.store ( os, format );
      os << this->createEndTag( "mu2" ) << std::endl;
    }

    if ( this->sparse_sigma2.size() > 0 )
    {
      os << this->createStartTag( "sparse_sigma2" ) << std::endl;
      this->sparse_sigma2.store ( os, format );
      os << this->createEndTag( "sparse_sigma2" ) << std::endl;
    }

    if ( this->priors2.size() > 0 )
    {
      os << this->createStartTag( "priors2" ) << std::endl;
      os << this->priors2.size () << std::endl;
      for ( std::vector< double >::const_iterator itPriors2 = this->priors2.begin();
            itPriors2 != this->priors2.end();
            itPriors2++
          )
      {
        os << *itPriors2 << " ";
      }
      os << std::endl;
      os << this->createEndTag( "priors2" ) << std::endl;
    }

    os << this->createStartTag( "b_compareTo2ndGMM" ) << std::endl;
    os << this->b_compareTo2ndGMM << std::endl;
    os << this->createEndTag( "b_compareTo2ndGMM" ) << std::endl;

    os << this->createStartTag( "maxiter" ) << std::endl;
    os << this->maxiter << std::endl;
    os << this->createEndTag( "maxiter" ) << std::endl;

    os << this->createStartTag( "featsperclass" ) << std::endl;
    os << this->featsperclass << std::endl;
    os << this->createEndTag( "featsperclass" ) << std::endl;

    if ( cdimval != -1 )
    {
      os << this->createStartTag( "cdimval" ) << std::endl;
      os << this->cdimval << std::endl;
      os << this->createEndTag( "cdimval" ) << std::endl;
    }

    os << this->createStartTag( "tau" ) << std::endl;
    os << this->tau << std::endl;
    os << this->createEndTag( "tau" ) << std::endl;

    os << this->createStartTag( "pyramid" ) << std::endl;
    os << this->pyramid << std::endl;
    os << this->createEndTag( "pyramid" ) << std::endl;

    // done
    os << this->createEndTag( "GMM" ) << std::endl;
  }
  else
  {
    std::cerr << "OutStream not initialized - storing not possible!" << std::endl;
  }
}
void GMM::clear ()
{
}
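
// A minimal usage sketch (not part of the original file; it assumes a filled
// NICE::VVector of feature vectors and the usual NICE/vislearning headers):
//
//   NICE::VVector features = ...;            // one NICE::Vector per data point
//   OBJREC::GMM gmm ( 16 );                  // mixture with 16 diagonal Gaussians
//   NICE::VVector prototypes;
//   std::vector<double> weights;
//   std::vector<int> assignment;
//   gmm.cluster ( features, prototypes, weights, assignment );
//
//   NICE::Vector probs;
//   gmm.getProbs ( features[0], probs );     // posterior over the components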