MutualInformation.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. /**
  2. * @file MutualInformation.cpp
  3. * @brief Part Selection with Mutual Information
  4. * @author Erik Rodner
  5. * @date 02/20/2008
  6. */
  7. #include <iostream>
  8. #include "MutualInformation.h"
  9. #include "vislearning/baselib/Gnuplot.h"
  10. using namespace OBJREC;
  11. using namespace std;
  12. using namespace NICE;
  13. MutualInformation::MutualInformation()
  14. {
  15. }
  16. MutualInformation::~MutualInformation()
  17. {
  18. }
  19. void MutualInformation::addStatistics ( const vector<Vector *> & v, size_t d, double threshold, size_t & ones ) const
  20. {
  21. for ( vector<Vector *>::const_iterator j = v.begin() ; j != v.end(); j++ )
  22. if ( (*(*j))[d] > threshold ) ones++;
  23. }
  24. double MutualInformation::entropy ( size_t n1, size_t n2 ) const
  25. {
  26. // - p_1 log p_1 - p_2 log p_2
  27. // aber log im bereich [0,1] numerisch instabil
  28. // daher
  29. // p_1 log 1/p_1 + p_2 log 1/p_2
  30. double sum = n1 + n2;
  31. double log1 = n1 > 0 ? log(n1) : 0;
  32. double log2 = n2 > 0 ? log(n2) : 0;
  33. double logsum = (sum > 0) ? log(sum) : 0;
  34. return - (n1/sum*log1 + n2/sum*log2) + logsum;
  35. }
  36. double MutualInformation::mutualInformationOverall ( const LabeledSetVector & ls,
  37. size_t dimension,
  38. double threshold ) const
  39. {
  40. double entropy_conditional = 0.0;
  41. for ( LabeledSetVector::const_iterator i = ls.begin();
  42. i != ls.end();
  43. i++ )
  44. {
  45. size_t ones = 0;
  46. addStatistics ( i->second, dimension, threshold, ones );
  47. double entropy_conditional_class = entropy ( ones, i->second.size() - ones );
  48. entropy_conditional += entropy_conditional_class;
  49. }
  50. entropy_conditional /= ls.size();
  51. double entropy_joint = 0.0;
  52. size_t ones = 0;
  53. size_t count = 0;
  54. for ( LabeledSetVector::const_iterator i = ls.begin();
  55. i != ls.end();
  56. i++ )
  57. {
  58. addStatistics ( i->second, dimension, threshold, ones );
  59. count += i->second.size();
  60. }
  61. entropy_joint = entropy ( ones, count - ones );
  62. return entropy_joint - entropy_conditional;
  63. }
  64. double MutualInformation::mutualInformationClass ( const LabeledSetVector & ls,
  65. size_t classno,
  66. size_t dimension,
  67. double threshold ) const
  68. {
  69. size_t ones_p = 0;
  70. LabeledSetVector::const_iterator iclassno = ls.find(classno);
  71. if ( iclassno == ls.end() )
  72. {
  73. fprintf (stderr, "MutualInformation::mutualInformationClass: classno %u not found\n", classno );
  74. exit(-1);
  75. }
  76. size_t count_p = iclassno->second.size();
  77. addStatistics ( iclassno->second, dimension, threshold, ones_p );
  78. double entropy_conditional_p = entropy ( ones_p, count_p - ones_p );
  79. size_t ones_n = 0;
  80. size_t count_n = 0;
  81. for ( LabeledSetVector::const_iterator i = ls.begin();
  82. i != ls.end();
  83. i++ )
  84. {
  85. if ( i->first != (int)classno );
  86. addStatistics ( i->second, dimension, threshold, ones_n );
  87. count_n += i->second.size();
  88. }
  89. double entropy_conditional_n = entropy ( ones_n, count_n - ones_n );
  90. double entropy_conditional = 0.5 * ( entropy_conditional_n + entropy_conditional_p );
  91. double entropy_joint = entropy ( ones_p + ones_n, count_p + count_n - ones_p - ones_n );
  92. return entropy_joint - entropy_conditional;
  93. }
  94. double MutualInformation::computeThresholdClass ( const LabeledSetVector & ls, size_t classno,
  95. size_t dimension, double & opt_threshold ) const
  96. {
  97. vector<double> thresholds;
  98. LOOP_ALL(ls)
  99. {
  100. EACH(classno, v);
  101. double val = v[dimension];
  102. thresholds.push_back ( val );
  103. }
  104. sort ( thresholds.begin(), thresholds.end() );
  105. thresholds.erase( std::unique( thresholds.begin(), thresholds.end()), thresholds.end());
  106. opt_threshold = 0.0;
  107. double opt_mi = 0.0;
  108. for ( vector<double>::const_iterator i = thresholds.begin();
  109. i != thresholds.end();
  110. i++ )
  111. {
  112. vector<double>::const_iterator j = i + 1;
  113. if ( j == thresholds.end() ) break;
  114. double threshold = 0.5 * ((*i) + (*j));
  115. double mi = mutualInformationClass ( ls, classno, dimension, threshold );
  116. if ( mi > opt_mi ) {
  117. opt_mi = mi;
  118. opt_threshold = threshold;
  119. }
  120. }
  121. return opt_mi;
  122. }
  123. double MutualInformation::computeThresholdOverall ( const LabeledSetVector & ls, size_t dimension, double & opt_threshold ) const
  124. {
  125. vector<double> thresholds;
  126. LOOP_ALL(ls)
  127. {
  128. EACH(classno, v);
  129. double val = v[dimension];
  130. thresholds.push_back ( val );
  131. }
  132. sort ( thresholds.begin(), thresholds.end() );
  133. thresholds.erase( std::unique( thresholds.begin(), thresholds.end()), thresholds.end());
  134. opt_threshold = 0.0;
  135. double opt_mi = 0.0;
  136. #ifdef DEBUGMUTUALINFORMATION
  137. vector<double> x;
  138. vector<double> y;
  139. #endif
  140. for ( vector<double>::const_iterator i = thresholds.begin();
  141. i != thresholds.end();
  142. i++ )
  143. {
  144. vector<double>::const_iterator j = i + 1;
  145. if ( j == thresholds.end() ) break;
  146. double threshold = 0.5 * ((*i) + (*j));
  147. double mi = mutualInformationOverall ( ls, dimension, threshold );
  148. #ifdef DEBUGMUTUALINFORMATION
  149. x.push_back ( threshold );
  150. y.push_back ( mi );
  151. #endif
  152. if ( mi > opt_mi ) {
  153. opt_mi = mi;
  154. opt_threshold = threshold;
  155. }
  156. }
  157. #ifdef DEBUGMUTUALINFORMATION
  158. if ( x.size() > 0 )
  159. Gnuplot gnu ( "Mutual Information", "smooth csplines", "threshold",
  160. "mi", x, y );
  161. #endif
  162. return opt_mi;
  163. }
  164. void MutualInformation::computeThresholdsClass ( const LabeledSetVector & ls, size_t classno,
  165. NICE::Vector & thresholds, NICE::Vector & mis ) const
  166. {
  167. size_t max_dimension = ls.dimension();
  168. thresholds.clear();
  169. mis.clear();
  170. for ( size_t k = 0 ; k < max_dimension ; k++ )
  171. {
  172. double t, mi;
  173. mi = computeThresholdClass ( ls, classno, k, t );
  174. mis.append(mi);
  175. thresholds.append(t);
  176. }
  177. }
  178. void MutualInformation::computeThresholdsOverall ( const LabeledSetVector & ls, NICE::Vector & thresholds, NICE::Vector & mis ) const
  179. {
  180. size_t max_dimension = ls.dimension();
  181. thresholds.clear();
  182. mis.clear();
  183. for ( size_t k = 0 ; k < max_dimension ; k++ )
  184. {
  185. double t, mi;
  186. mi = computeThresholdOverall ( ls, k, t );
  187. mis.append(mi);
  188. thresholds.append(t);
  189. }
  190. }