MutualInformation.cpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. /**
  2. * @file MutualInformation.cpp
  3. * @brief Part Selection with Mutual Information
  4. * @author Erik Rodner
  5. * @date 02/20/2008
  6. */
  7. #include <iostream>
  8. #include "MutualInformation.h"
  9. #include "vislearning/baselib/Gnuplot.h"
  10. using namespace OBJREC;
  11. using namespace std;
  12. using namespace NICE;
  13. MutualInformation::MutualInformation( bool verbose )
  14. {
  15. this->verbose = verbose;
  16. }
  17. MutualInformation::~MutualInformation()
  18. {
  19. }
  20. void MutualInformation::addStatistics ( const vector<Vector *> & v, size_t d, double threshold, size_t & ones ) const
  21. {
  22. for ( vector<Vector *>::const_iterator j = v.begin() ; j != v.end(); j++ )
  23. if ( (*(*j))[d] > threshold ) ones++;
  24. }
  25. double MutualInformation::entropy ( size_t n1, size_t n2 ) const
  26. {
  27. // - p_1 log p_1 - p_2 log p_2
  28. // aber log im bereich [0,1] numerisch instabil
  29. // daher
  30. // p_1 log 1/p_1 + p_2 log 1/p_2
  31. double sum = n1 + n2;
  32. double log1 = n1 > 0 ? log( (double)n1) : 0;
  33. double log2 = n2 > 0 ? log( (double)n2) : 0;
  34. double logsum = (sum > 0) ? log(sum) : 0;
  35. return - (n1/sum*log1 + n2/sum*log2) + logsum;
  36. }
  37. double MutualInformation::mutualInformationOverall ( const LabeledSetVector & ls,
  38. size_t dimension,
  39. double threshold ) const
  40. {
  41. double entropy_conditional = 0.0;
  42. for ( LabeledSetVector::const_iterator i = ls.begin();
  43. i != ls.end();
  44. i++ )
  45. {
  46. size_t ones = 0;
  47. addStatistics ( i->second, dimension, threshold, ones );
  48. double entropy_conditional_class = entropy ( ones, i->second.size() - ones );
  49. entropy_conditional += entropy_conditional_class;
  50. }
  51. entropy_conditional /= ls.size();
  52. double entropy_joint = 0.0;
  53. size_t ones = 0;
  54. size_t count = 0;
  55. for ( LabeledSetVector::const_iterator i = ls.begin();
  56. i != ls.end();
  57. i++ )
  58. {
  59. addStatistics ( i->second, dimension, threshold, ones );
  60. count += i->second.size();
  61. }
  62. entropy_joint = entropy ( ones, count - ones );
  63. return entropy_joint - entropy_conditional;
  64. }
  65. double MutualInformation::mutualInformationClass ( const LabeledSetVector & ls,
  66. size_t classno,
  67. size_t dimension,
  68. double threshold ) const
  69. {
  70. size_t ones_p = 0;
  71. LabeledSetVector::const_iterator iclassno = ls.find(classno);
  72. if ( iclassno == ls.end() )
  73. {
  74. fprintf (stderr, "MutualInformation::mutualInformationClass: classno %u not found\n", classno );
  75. exit(-1);
  76. }
  77. size_t count_p = iclassno->second.size();
  78. addStatistics ( iclassno->second, dimension, threshold, ones_p );
  79. double entropy_conditional_p = entropy ( ones_p, count_p - ones_p );
  80. size_t ones_n = 0;
  81. size_t count_n = 0;
  82. for ( LabeledSetVector::const_iterator i = ls.begin();
  83. i != ls.end();
  84. i++ )
  85. {
  86. if ( i->first != (int)classno );
  87. addStatistics ( i->second, dimension, threshold, ones_n );
  88. count_n += i->second.size();
  89. fprintf(stderr,"This exception means: review code and check the statement 'if ( i->first != (int)classno );' in line 108 / file MutualInformation.cpp. It contains an empty control statement and might be an error. If nonetheless desired behavoir, delete this exception throw.");
  90. }
  91. double entropy_conditional_n = entropy ( ones_n, count_n - ones_n );
  92. double entropy_conditional = 0.5 * ( entropy_conditional_n + entropy_conditional_p );
  93. double entropy_joint = entropy ( ones_p + ones_n, count_p + count_n - ones_p - ones_n );
  94. return entropy_joint - entropy_conditional;
  95. }
  96. double MutualInformation::computeThresholdClass ( const LabeledSetVector & ls, size_t classno,
  97. size_t dimension, double & opt_threshold ) const
  98. {
  99. vector<double> thresholds;
  100. LOOP_ALL(ls)
  101. {
  102. EACH(classno, v);
  103. double val = v[dimension];
  104. thresholds.push_back ( val );
  105. }
  106. sort ( thresholds.begin(), thresholds.end() );
  107. thresholds.erase( std::unique( thresholds.begin(), thresholds.end()), thresholds.end());
  108. opt_threshold = 0.0;
  109. double opt_mi = 0.0;
  110. for ( vector<double>::const_iterator i = thresholds.begin();
  111. i != thresholds.end();
  112. i++ )
  113. {
  114. vector<double>::const_iterator j = i + 1;
  115. if ( j == thresholds.end() ) break;
  116. double threshold = 0.5 * ((*i) + (*j));
  117. double mi = mutualInformationClass ( ls, classno, dimension, threshold );
  118. if ( mi > opt_mi ) {
  119. opt_mi = mi;
  120. opt_threshold = threshold;
  121. }
  122. }
  123. return opt_mi;
  124. }
  125. double MutualInformation::computeThresholdOverall ( const LabeledSetVector & ls, size_t dimension, double & opt_threshold ) const
  126. {
  127. vector<double> thresholds;
  128. vector<int> y;
  129. LOOP_ALL(ls)
  130. {
  131. EACH(classno, v);
  132. double val = v[dimension];
  133. thresholds.push_back ( val );
  134. y.push_back(classno);
  135. }
  136. sort ( thresholds.begin(), thresholds.end() );
  137. thresholds.erase( std::unique( thresholds.begin(), thresholds.end()), thresholds.end());
  138. opt_threshold = 0.0;
  139. double opt_mi = 0.0;
  140. uint ind = 0;
  141. for ( vector<double>::const_iterator i = thresholds.begin();
  142. i != thresholds.end(); i++, ind++ )
  143. {
  144. vector<double>::const_iterator j = i + 1;
  145. if ( j == thresholds.end() ) break;
  146. // the optimimum can not be found at non-class borders
  147. if ( y[ind] == y[ind+1] ) continue;
  148. double threshold = 0.5 * ((*i) + (*j));
  149. // FIXME: This call is pretty inefficient!!
  150. // We can directly count the features here...might be 100times faster :)
  151. double mi = mutualInformationOverall ( ls, dimension, threshold );
  152. if ( mi > opt_mi ) {
  153. opt_mi = mi;
  154. opt_threshold = threshold;
  155. }
  156. }
  157. return opt_mi;
  158. }
  159. void MutualInformation::computeThresholdsClass ( const LabeledSetVector & ls, size_t classno,
  160. NICE::Vector & thresholds, NICE::Vector & mis ) const
  161. {
  162. size_t max_dimension = ls.dimension();
  163. thresholds.resize(max_dimension);
  164. mis.resize(max_dimension);
  165. for ( size_t k = 0 ; k < max_dimension ; k++ )
  166. {
  167. double t, mi;
  168. mi = computeThresholdClass ( ls, classno, k, t );
  169. mis[k] = mi;
  170. thresholds[k] = t;
  171. }
  172. }
  173. void MutualInformation::computeThresholdsOverall ( const LabeledSetVector & ls, NICE::Vector & thresholds, NICE::Vector & mis ) const
  174. {
  175. size_t max_dimension = ls.dimension();
  176. thresholds.resize(max_dimension);
  177. mis.resize(max_dimension);
  178. for ( size_t k = 0 ; k < max_dimension ; k++ )
  179. {
  180. if ( verbose )
  181. cerr << "MutualInformation: Optimizing threshold for feature " << k << " / " << max_dimension << endl;
  182. double t, mi;
  183. mi = computeThresholdOverall ( ls, k, t );
  184. mis[k] = mi;
  185. thresholds[k] = t;
  186. }
  187. }