// MutualInformation.cpp 6.1 KB
  /**
   * @file MutualInformation.cpp
   * @brief Part Selection with Mutual Information
   * @author Erik Rodner
   * @date 02/20/2008
   */
  #include <algorithm>
  #include <iostream>
  #include <utility>
  #include <vector>

  #include "MutualInformation.h"
  #include "vislearning/baselib/Gnuplot.h"

  using namespace OBJREC;
  using namespace std;
  using namespace NICE;
  13. MutualInformation::MutualInformation( bool verbose )
  14. {
  15. this->verbose = verbose;
  16. }
  17. MutualInformation::~MutualInformation()
  18. {
  19. }
  20. void MutualInformation::addStatistics ( const vector<Vector *> & v, size_t d, double threshold, size_t & ones ) const
  21. {
  22. for ( vector<Vector *>::const_iterator j = v.begin() ; j != v.end(); j++ )
  23. if ( (*(*j))[d] > threshold ) ones++;
  24. }
  25. double MutualInformation::entropy ( size_t n1, size_t n2 ) const
  26. {
  27. // - p_1 log p_1 - p_2 log p_2
  28. // aber log im bereich [0,1] numerisch instabil
  29. // daher
  30. // p_1 log 1/p_1 + p_2 log 1/p_2
  31. double sum = n1 + n2;
  32. double log1 = n1 > 0 ? log(n1) : 0;
  33. double log2 = n2 > 0 ? log(n2) : 0;
  34. double logsum = (sum > 0) ? log(sum) : 0;
  35. return - (n1/sum*log1 + n2/sum*log2) + logsum;
  36. }
  37. double MutualInformation::mutualInformationOverall ( const LabeledSetVector & ls,
  38. size_t dimension,
  39. double threshold ) const
  40. {
  41. double entropy_conditional = 0.0;
  42. for ( LabeledSetVector::const_iterator i = ls.begin();
  43. i != ls.end();
  44. i++ )
  45. {
  46. size_t ones = 0;
  47. addStatistics ( i->second, dimension, threshold, ones );
  48. double entropy_conditional_class = entropy ( ones, i->second.size() - ones );
  49. entropy_conditional += entropy_conditional_class;
  50. }
  51. entropy_conditional /= ls.size();
  52. double entropy_joint = 0.0;
  53. size_t ones = 0;
  54. size_t count = 0;
  55. for ( LabeledSetVector::const_iterator i = ls.begin();
  56. i != ls.end();
  57. i++ )
  58. {
  59. addStatistics ( i->second, dimension, threshold, ones );
  60. count += i->second.size();
  61. }
  62. entropy_joint = entropy ( ones, count - ones );
  63. return entropy_joint - entropy_conditional;
  64. }
  65. double MutualInformation::mutualInformationClass ( const LabeledSetVector & ls,
  66. size_t classno,
  67. size_t dimension,
  68. double threshold ) const
  69. {
  70. size_t ones_p = 0;
  71. LabeledSetVector::const_iterator iclassno = ls.find(classno);
  72. if ( iclassno == ls.end() )
  73. {
  74. fprintf (stderr, "MutualInformation::mutualInformationClass: classno %u not found\n", classno );
  75. exit(-1);
  76. }
  77. size_t count_p = iclassno->second.size();
  78. addStatistics ( iclassno->second, dimension, threshold, ones_p );
  79. double entropy_conditional_p = entropy ( ones_p, count_p - ones_p );
  80. size_t ones_n = 0;
  81. size_t count_n = 0;
  82. for ( LabeledSetVector::const_iterator i = ls.begin();
  83. i != ls.end();
  84. i++ )
  85. {
  86. if ( i->first != (int)classno );
  87. addStatistics ( i->second, dimension, threshold, ones_n );
  88. count_n += i->second.size();
  89. }
  90. double entropy_conditional_n = entropy ( ones_n, count_n - ones_n );
  91. double entropy_conditional = 0.5 * ( entropy_conditional_n + entropy_conditional_p );
  92. double entropy_joint = entropy ( ones_p + ones_n, count_p + count_n - ones_p - ones_n );
  93. return entropy_joint - entropy_conditional;
  94. }
  95. double MutualInformation::computeThresholdClass ( const LabeledSetVector & ls, size_t classno,
  96. size_t dimension, double & opt_threshold ) const
  97. {
  98. vector<double> thresholds;
  99. LOOP_ALL(ls)
  100. {
  101. EACH(classno, v);
  102. double val = v[dimension];
  103. thresholds.push_back ( val );
  104. }
  105. sort ( thresholds.begin(), thresholds.end() );
  106. thresholds.erase( std::unique( thresholds.begin(), thresholds.end()), thresholds.end());
  107. opt_threshold = 0.0;
  108. double opt_mi = 0.0;
  109. for ( vector<double>::const_iterator i = thresholds.begin();
  110. i != thresholds.end();
  111. i++ )
  112. {
  113. vector<double>::const_iterator j = i + 1;
  114. if ( j == thresholds.end() ) break;
  115. double threshold = 0.5 * ((*i) + (*j));
  116. double mi = mutualInformationClass ( ls, classno, dimension, threshold );
  117. if ( mi > opt_mi ) {
  118. opt_mi = mi;
  119. opt_threshold = threshold;
  120. }
  121. }
  122. return opt_mi;
  123. }
  124. double MutualInformation::computeThresholdOverall ( const LabeledSetVector & ls, size_t dimension, double & opt_threshold ) const
  125. {
  126. vector<double> thresholds;
  127. vector<int> y;
  128. LOOP_ALL(ls)
  129. {
  130. EACH(classno, v);
  131. double val = v[dimension];
  132. thresholds.push_back ( val );
  133. y.push_back(classno);
  134. }
  135. sort ( thresholds.begin(), thresholds.end() );
  136. thresholds.erase( std::unique( thresholds.begin(), thresholds.end()), thresholds.end());
  137. opt_threshold = 0.0;
  138. double opt_mi = 0.0;
  139. uint ind = 0;
  140. for ( vector<double>::const_iterator i = thresholds.begin();
  141. i != thresholds.end(); i++, ind++ )
  142. {
  143. vector<double>::const_iterator j = i + 1;
  144. if ( j == thresholds.end() ) break;
  145. // the optimimum can not be found at non-class borders
  146. if ( y[ind] == y[ind+1] ) continue;
  147. double threshold = 0.5 * ((*i) + (*j));
  148. // FIXME: This call is pretty inefficient!!
  149. // We can directly count the features here...might be 100times faster :)
  150. double mi = mutualInformationOverall ( ls, dimension, threshold );
  151. if ( mi > opt_mi ) {
  152. opt_mi = mi;
  153. opt_threshold = threshold;
  154. }
  155. }
  156. return opt_mi;
  157. }
  158. void MutualInformation::computeThresholdsClass ( const LabeledSetVector & ls, size_t classno,
  159. NICE::Vector & thresholds, NICE::Vector & mis ) const
  160. {
  161. size_t max_dimension = ls.dimension();
  162. thresholds.resize(max_dimension);
  163. mis.resize(max_dimension);
  164. for ( size_t k = 0 ; k < max_dimension ; k++ )
  165. {
  166. double t, mi;
  167. mi = computeThresholdClass ( ls, classno, k, t );
  168. mis[k] = mi;
  169. thresholds[k] = t;
  170. }
  171. }
  172. void MutualInformation::computeThresholdsOverall ( const LabeledSetVector & ls, NICE::Vector & thresholds, NICE::Vector & mis ) const
  173. {
  174. size_t max_dimension = ls.dimension();
  175. thresholds.resize(max_dimension);
  176. mis.resize(max_dimension);
  177. for ( size_t k = 0 ; k < max_dimension ; k++ )
  178. {
  179. if ( verbose )
  180. cerr << "MutualInformation: Optimizing threshold for feature " << k << " / " << max_dimension << endl;
  181. double t, mi;
  182. mi = computeThresholdOverall ( ls, k, t );
  183. mis[k] = mi;
  184. thresholds[k] = t;
  185. }
  186. }