GMM.h 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. /**
  2. * @file GMM.h
  3. * @brief Gaussian Mixture Model based on
  4. article{Calinon07SMC,
  5. title="On Learning, Representing and Generalizing a Task in a Humanoid
  6. Robot",
  7. author="S. Calinon and F. Guenter and A. Billard",
  8. journal="IEEE Transactions on Systems, Man and Cybernetics, Part B.
  9. Special issue on robot learning by observation, demonstration and
  10. imitation",
  11. year="2007",
  12. volume="37",
  13. number="2",
  14. pages="286--298"
  15. }
  16. * @author Björn Fröhlich, Alexander Freytag
  17. * @date 05/14/2009
  18. */
  19. #ifndef GMMINCLUDE
  20. #define GMMINCLUDE
  21. #include "core/vector/VectorT.h"
  22. #include "core/vector/MatrixT.h"
  23. #include "vislearning/cbaselib/MultiDataset.h"
  24. #include "vislearning/cbaselib/LocalizationResult.h"
  25. #include "vislearning/cbaselib/CachedExample.h"
  26. #include "vislearning/cbaselib/Example.h"
  27. #include "vislearning/math/cluster/ClusterAlgorithm.h"
  28. #include "core/vector/VVector.h"
  29. namespace OBJREC {
  30. class GMM : public ClusterAlgorithm // Gaussian Mixture Model exposed through the ClusterAlgorithm interface
  31. {
  32. protected:
  33. //! number of Gaussians in the mixture
  34. int i_numOfGaussians;
  35. //! dimension of each feature vector
  36. int dim;
  37. //! mean vectors (presumably one per Gaussian)
  38. NICE::VVector mu;
  39. //! sparse sigma vectors (presumably the diagonals of the covariance matrices -- TODO confirm)
  40. NICE::VVector sparse_sigma;
  41. //! the weight (prior) of each Gaussian
  42. std::vector<double> priors;
  43. //! sparse inverse sigma vectors (if usediag); NOTE(review): no usediag member is declared in this class
  44. NICE::VVector sparse_inv_sigma;
  45. //! cached determinant term of sigma for fast computation (the name suggests the log-determinant is stored -- confirm)
  46. std::vector<double> log_det_sigma;
  47. //! parameters of another GMM to compare against (see setGMMtoCompareWith)
  48. NICE::VVector mu2;
  49. NICE::VVector sparse_sigma2;
  50. std::vector<double> priors2;
  51. bool b_compareTo2ndGMM;
  52. //! maximum number of iterations for EM
  53. int maxiter;
  54. //! how many features to use, use 0 for all input features
  55. int featsperclass;
  56. //! cached constant for faster computation: cdimval = dim*2*PI
  57. double cdimval;
  58. //! parameter for the MAP (presumably maximum a posteriori) estimation
  59. double tau;
  60. //! whether to use pyramid initialisation or not
  61. bool pyramid;
  62. public:
  63. ///////////////////// ///////////////////// /////////////////////
  64. // CONSTRUCTORS / DESTRUCTORS
  65. ///////////////////// ///////////////////// /////////////////////
  66. /**
  67. * default constructor (takes no arguments)
  68. */
  69. GMM();
  70. /**
  71. * simple constructor
  72. * @param _numOfGaussians number of Gaussians
  73. */
  74. GMM ( int _numOfGaussians );
  75. /**
  76. * standard constructor
  77. * @param conf a config file object
  78. * @param _numOfGaussians number of Gaussians
  79. */
  80. GMM ( const NICE::Config *conf, int _numOfGaussians );
  81. /**
  82. * @brief recommended constructor
  83. * @author Alexander Freytag
  84. * @date 14-02-2014 ( dd-mm-yyyy )
  85. * @param _conf a config file object
  86. * @param _confSection tag specifying the part in the config file (defaults to "GMM")
  87. */
  88. GMM ( const NICE::Config * _conf, const std::string & _confSection = "GMM" );
  89. /**
  90. * standard destructor
  91. */
  92. ~GMM();
  93. /**
  94. * @brief Jobs previously performed in the config-version of the constructor (reading settings etc.); _confSection defaults to "GMM"
  95. * @author Alexander Freytag
  96. * @date 13-02-2014 ( dd-mm-yyyy )
  97. */
  98. void initFromConfig ( const NICE::Config * _conf, const std::string & _confSection = "GMM");
  99. ///////////////////// ///////////////////// /////////////////////
  100. // CLUSTERING STUFF
  101. ///////////////////// ///////////////////// /////////////////////
  102. /**
  103. * computes the mixture
  104. * @param examples the input features (taken by value)
  105. */
  106. void computeMixture ( Examples examples );
  107. /**
  108. * computes the mixture
  109. * @param DataSet the input features
  110. */
  111. void computeMixture ( const NICE::VVector &DataSet );
  112. /**
  113. * returns the probabilities for each Gaussian in a sparse vector
  114. * @param vin input vector
  115. * @param probs BoV output vector (sparse)
  116. */
  117. void getProbs ( const NICE::Vector &vin, NICE::SparseVector &probs );
  118. /**
  119. * returns the probabilities for each Gaussian
  120. * @param vin input vector
  121. * @param probs BoV output vector (dense)
  122. */
  123. void getProbs ( const NICE::Vector &vin, NICE::Vector &probs );
  124. /**
  125. * returns the Fisher score for the GMM
  126. * @param vin input vector
  127. * @param probs Fisher score output vector
  128. */
  129. void getFisher ( const NICE::Vector &vin, NICE::SparseVector &probs );
  130. /**
  131. * init the Gaussian mixture by selecting randomized mean vectors and using the covariance of all features
  132. * @param DataSet input features (a NICE::VVector)
  133. */
  134. void initEM ( const NICE::VVector &DataSet );
  135. /**
  136. * alternative for initEM: init the Gaussian mixture with a K-Means clustering
  137. * @param DataSet input features (a NICE::VVector)
  138. */
  139. void initEMkMeans ( const NICE::VVector &DataSet );
  140. /**
  141. * performs Expectation Maximization on the DataSet in order to obtain an nState GMM
  142. * @param DataSet input features; NOTE(review): earlier docs describe it as a Matrix(no_classes,nDimensions) although the type is NICE::VVector -- confirm
  143. * @param nbgaussians number of Gaussians to use
  144. * @return number of iterations
  145. */
  146. int doEM ( const NICE::VVector &DataSet, int nbgaussians );
  147. /**
  148. * Compute log probability of vector Vin for the given state.
  149. * @param Vin input vector
  150. * @param state the state to evaluate (presumably the index of a Gaussian -- TODO confirm)
  151. * @return the log probability
  152. */
  153. double logpdfState ( const NICE::Vector &Vin, int state );
  154. /**
  155. * determine the best mixture for the input feature
  156. * @param v input feature
  157. * @param bprob probability of the best mixture (optional pointer, defaults to NULL)
  158. * @return number of the best mixture
  159. */
  160. int getBestClass ( const NICE::Vector &v, double *bprob = NULL );
  161. /**
  162. * Cluster a given set of features and return the labels for each feature
  163. * @param features input features
  164. * @param prototypes mean of the best Gaussian
  165. * @param weights weight of the best Gaussian
  166. * @param assignment number of the best Gaussian
  167. */
  168. void cluster ( const NICE::VVector & features, NICE::VVector & prototypes, std::vector<double> & weights, std::vector<int> & assignment );
  169. /**
  170. * save GMM data
  171. * @param filename filename
  172. * @deprecated use methods inherited from persistent instead!
  173. */
  174. void saveData ( const std::string filename );
  175. /**
  176. * load GMM data
  177. * @param filename filename
  178. * @return true if everything works fine
  179. * @deprecated use methods inherited from persistent instead!
  180. */
  181. bool loadData ( const std::string filename );
  182. /**
  183. * return the parameters of the current mixture model through the reference arguments
  184. * @param mean mean vectors (mu)
  185. * @param sSigma sparse sigma vectors
  186. * @param p weights (presumably the priors -- TODO confirm)
  187. */
  188. void getParams ( NICE::VVector &mean, NICE::VVector &sSigma, std::vector<double> &p ) const;
  189. /**
  190. * Set the parameters of another mixture for comparing with this one
  191. * @param mean mean vectors
  192. * @param sSigma diagonal covariance matrices
  193. * @param p weights
  194. */
  195. void setGMMtoCompareWith ( NICE::VVector mean, NICE::VVector sSigma, std::vector<double> p );
  196. /**
  197. * probability product kernel
  198. * @param sigma1 presumably the sigma vector of the first Gaussian -- TODO confirm
  199. * @param sigma2 presumably the sigma vector of the second Gaussian -- TODO confirm
  200. * @param mu1 presumably the mean vector of the first Gaussian -- TODO confirm
  201. * @param mu2 presumably the mean vector of the second Gaussian -- TODO confirm
  202. * @param p kernel parameter; its meaning is not documented here -- TODO confirm
  203. * @return the kernel value
  204. */
  205. double kPPK ( NICE::Vector sigma1, NICE::Vector sigma2, NICE::Vector mu1, NICE::Vector mu2, double p ) const;
  206. /**
  207. * starts a comparison between this mixture and another one set by "setGMMtoCompareWith"
  208. */
  209. double compareTo2ndGMM() const;
  210. /**
  211. * whether to compare or not
  212. * @param _compareTo2ndGMM the new setting (defaults to true)
  213. */
  214. void setCompareTo2ndGMM ( const bool & _compareTo2ndGMM = true );
  215. int getNumberOfGaussians() const; //!< presumably returns i_numOfGaussians -- TODO confirm
  216. ///////////////////// INTERFACE PERSISTENT /////////////////////
  217. // interface specific methods for store and restore
  218. ///////////////////// INTERFACE PERSISTENT /////////////////////
  219. /**
  220. * @brief Load object from external file (stream); format defaults to 0
  221. * @author Alexander Freytag
  222. * @date 13-02-2014 ( dd-mm-yyyy )
  223. */
  224. void restore ( std::istream & is, int format = 0 );
  225. /**
  226. * @brief Save object to external file (stream); format defaults to 0
  227. * @author Alexander Freytag
  228. * @date 13-02-2014 ( dd-mm-yyyy )
  229. */
  230. void store ( std::ostream & os, int format = 0 ) const;
  231. /**
  232. * @brief Clear object
  233. * @author Alexander Freytag
  234. * @date 13-02-2014 ( dd-mm-yyyy )
  235. */
  236. void clear ();
  237. };
  238. } // namespace
  239. #endif