GMM.h 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  /**
   * @file GMM.h
   * @brief Gaussian Mixture Model based on the following article:
   *
   *   S. Calinon, F. Guenter and A. Billard,
   *   "On Learning, Representing and Generalizing a Task in a Humanoid Robot",
   *   IEEE Transactions on Systems, Man and Cybernetics, Part B,
   *   Special issue on robot learning by observation, demonstration and imitation,
   *   vol. 37, no. 2, pp. 286--298, 2007 (BibTeX key: Calinon07SMC).
   *
   * @author Björn Fröhlich
   * @date 05/14/2009
   */
  19. #ifndef GMMINCLUDE
  20. #define GMMINCLUDE
  21. #include "core/vector/VectorT.h"
  22. #include "core/vector/MatrixT.h"
  23. #include "vislearning/cbaselib/MultiDataset.h"
  24. #include "vislearning/cbaselib/LocalizationResult.h"
  25. #include "vislearning/cbaselib/CachedExample.h"
  26. #include "vislearning/cbaselib/Example.h"
  27. #include "vislearning/math/cluster/ClusterAlgorithm.h"
  28. #include "core/vector/VVector.h"
  29. namespace OBJREC {
  30. class GMM : public ClusterAlgorithm
  31. {
  32. protected:
  33. //! number of gaussians
  34. int gaussians;
  35. //! dimension of each feature
  36. int dim;
  37. //! mean vectors
  38. NICE::VVector mu;
  39. //! sparse sigma vectors
  40. NICE::VVector sparse_sigma;
  41. //! sparse inverse sigma vectors (if usediag)
  42. NICE::VVector sparse_inv_sigma;
  43. //! save det_sigma for fast computing
  44. std::vector<double> log_det_sigma;
  45. //! the configfile
  46. const NICE::Config *conf;
  47. //! the weight for each gaussian
  48. std::vector<double> priors;
  49. //! parameters for other GM to compare
  50. NICE::VVector mu2;
  51. NICE::VVector sparse_sigma2;
  52. std::vector<double> priors2;
  53. bool comp;
  54. //! maximum number of iterations for EM
  55. int maxiter;
  56. //! how many features to use, use 0 for alle input features
  57. int featsperclass;
  58. //! for faster computing cdimval = dim*2*PI
  59. double cdimval;
  60. //! parameter for the map estimation
  61. double tau;
  62. //! whether to use pyramid initialisation or not
  63. bool pyramid;
  64. public:
  65. /**
  66. * simplest constructor
  67. */
  68. GMM();
  69. /**
  70. * simple constructor
  71. * @param _no_classes
  72. */
  73. GMM ( int _no_classes );
  74. /**
  75. * standard constructor
  76. * @param conf a Configfile
  77. * @param _no_classes number of gaussian
  78. */
  79. GMM ( const NICE::Config *conf, int _no_classes = -1 );
  80. /**
  81. * standard destructor
  82. */
  83. ~GMM() {
  84. std::cerr << "dadada" << std::endl;
  85. };
  86. /**
  87. * computes the mixture
  88. * @param examples the input features
  89. */
  90. void computeMixture ( Examples examples );
  91. /**
  92. * computes the mixture
  93. * @param DataSet the input features
  94. */
  95. void computeMixture ( const NICE::VVector &DataSet );
  96. /**
  97. * returns the probabilities for each gaussian in a sparse vector
  98. * @param vin input vector
  99. * @param probs BoV output vector
  100. */
  101. void getProbs ( const NICE::Vector &vin, NICE::SparseVector &probs );
  102. /**
  103. * returns the probabilities for each gaussian
  104. * @param vin input vector
  105. * @param probs BoV output vector
  106. */
  107. void getProbs ( const NICE::Vector &vin, NICE::Vector &probs );
  108. /**
  109. * returns the fisher score for the gmm
  110. * @param vin input vector
  111. * @param probs Fisher score output vector
  112. */
  113. void getFisher ( const NICE::Vector &vin, NICE::SparseVector &probs );
  114. /**
  115. * init the GaussianMixture by selecting randomized mean vectors and using the coovariance of all features
  116. * @param DataSet input Matrix
  117. */
  118. void initEM ( const NICE::VVector &DataSet );
  119. /**
  120. * alternative for initEM: init the GaussianMixture with a K-Means clustering
  121. * @param DataSet input Matrix
  122. */
  123. void initEMkMeans ( const NICE::VVector &DataSet );
  124. /**
  125. * performs Expectation Maximization on the Dataset, in order to obtain a nState GMM Dataset is a Matrix(no_classes,nDimensions)
  126. * @param DataSet input Matrix
  127. * @param gaussians number gaussians to use
  128. * @return number of iterations
  129. */
  130. int doEM ( const NICE::VVector &DataSet, int nbgaussians );
  131. /**
  132. * Compute log probabilty of vector v for the given state. *
  133. * @param Vin
  134. * @param state
  135. * @return
  136. */
  137. double logpdfState ( const NICE::Vector &Vin, int state );
  138. /**
  139. * determine the best mixture for the input feature
  140. * @param v input feature
  141. * @param bprob probability of the best mixture
  142. * @return numer of the best mixture
  143. */
  144. int getBestClass ( const NICE::Vector &v, double *bprob = NULL );
  145. /**
  146. * Cluster a given Set of features and return the labels for each feature
  147. * @param features input features
  148. * @param prototypes mean of the best gaussian
  149. * @param weights weight of the best gaussian
  150. * @param assignment number of the best gaussian
  151. */
  152. void cluster ( const NICE::VVector & features, NICE::VVector & prototypes, std::vector<double> & weights, std::vector<int> & assignment );
  153. /**
  154. * save GMM data
  155. * @param filename filename
  156. */
  157. void saveData ( const std::string filename );
  158. /**
  159. * load GMM data
  160. * @param filename filename
  161. * @return true if everything works fine
  162. */
  163. bool loadData ( const std::string filename );
  164. /**
  165. * return the parameter of the mixture
  166. * @param mu
  167. * @param sSigma
  168. * @param p
  169. */
  170. void getParams ( NICE::VVector &mean, NICE::VVector &sSigma, std::vector<double> &p );
  171. /**
  172. * Set the parameters of an other mixture for comparing with this one
  173. * @param mean mean vectors
  174. * @param sSigma diagonal covariance Matrixs
  175. * @param p weights
  176. */
  177. void setCompareGM ( NICE::VVector mean, NICE::VVector sSigma, std::vector<double> p );
  178. /**
  179. * probability product kernel
  180. * @param sigma1
  181. * @param sigma2
  182. * @param mu1
  183. * @param mu2
  184. * @param p
  185. * @return
  186. */
  187. double kPPK ( NICE::Vector sigma1, NICE::Vector sigma2, NICE::Vector mu1, NICE::Vector mu2, double p );
  188. /**
  189. * starts a comparison between this Mixture and a other one seted bei "setComparGM"
  190. */
  191. double compare();
  192. /**
  193. * whether to compare or not
  194. * @param c
  195. */
  196. void comparing ( bool c = true );
  197. int getSize() {
  198. return gaussians;
  199. }
  200. };
  201. } // namespace
  202. #endif