|
@@ -17,8 +17,9 @@ using namespace NICE;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-MutualInformation::MutualInformation()
|
|
|
|
|
|
+MutualInformation::MutualInformation( bool verbose )
|
|
{
|
|
{
|
|
|
|
+ this->verbose = verbose;
|
|
}
|
|
}
|
|
|
|
|
|
MutualInformation::~MutualInformation()
|
|
MutualInformation::~MutualInformation()
|
|
@@ -155,47 +156,41 @@ double MutualInformation::computeThresholdClass ( const LabeledSetVector & ls, s
|
|
double MutualInformation::computeThresholdOverall ( const LabeledSetVector & ls, size_t dimension, double & opt_threshold ) const
|
|
double MutualInformation::computeThresholdOverall ( const LabeledSetVector & ls, size_t dimension, double & opt_threshold ) const
|
|
{
|
|
{
|
|
vector<double> thresholds;
|
|
vector<double> thresholds;
|
|
|
|
+ vector<int> y;
|
|
LOOP_ALL(ls)
|
|
LOOP_ALL(ls)
|
|
{
|
|
{
|
|
EACH(classno, v);
|
|
EACH(classno, v);
|
|
double val = v[dimension];
|
|
double val = v[dimension];
|
|
thresholds.push_back ( val );
|
|
thresholds.push_back ( val );
|
|
|
|
+ y.push_back(classno);
|
|
}
|
|
}
|
|
sort ( thresholds.begin(), thresholds.end() );
|
|
sort ( thresholds.begin(), thresholds.end() );
|
|
thresholds.erase( std::unique( thresholds.begin(), thresholds.end()), thresholds.end());
|
|
thresholds.erase( std::unique( thresholds.begin(), thresholds.end()), thresholds.end());
|
|
|
|
|
|
opt_threshold = 0.0;
|
|
opt_threshold = 0.0;
|
|
double opt_mi = 0.0;
|
|
double opt_mi = 0.0;
|
|
-#ifdef DEBUGMUTUALINFORMATION
|
|
|
|
- vector<double> x;
|
|
|
|
- vector<double> y;
|
|
|
|
-#endif
|
|
|
|
|
|
|
|
|
|
+ uint ind = 0;
|
|
for ( vector<double>::const_iterator i = thresholds.begin();
|
|
for ( vector<double>::const_iterator i = thresholds.begin();
|
|
- i != thresholds.end();
|
|
|
|
- i++ )
|
|
|
|
|
|
+ i != thresholds.end(); i++, ind++ )
|
|
{
|
|
{
|
|
vector<double>::const_iterator j = i + 1;
|
|
vector<double>::const_iterator j = i + 1;
|
|
if ( j == thresholds.end() ) break;
|
|
if ( j == thresholds.end() ) break;
|
|
|
|
|
|
|
|
+ // the optimimum can not be found at non-class borders
|
|
|
|
+ if ( y[ind] == y[ind+1] ) continue;
|
|
|
|
+
|
|
double threshold = 0.5 * ((*i) + (*j));
|
|
double threshold = 0.5 * ((*i) + (*j));
|
|
|
|
+
|
|
|
|
+ // FIXME: This call is pretty inefficient!!
|
|
|
|
+ // We can directly count the features here...might be 100times faster :)
|
|
double mi = mutualInformationOverall ( ls, dimension, threshold );
|
|
double mi = mutualInformationOverall ( ls, dimension, threshold );
|
|
-#ifdef DEBUGMUTUALINFORMATION
|
|
|
|
- x.push_back ( threshold );
|
|
|
|
- y.push_back ( mi );
|
|
|
|
-#endif
|
|
|
|
if ( mi > opt_mi ) {
|
|
if ( mi > opt_mi ) {
|
|
opt_mi = mi;
|
|
opt_mi = mi;
|
|
opt_threshold = threshold;
|
|
opt_threshold = threshold;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-#ifdef DEBUGMUTUALINFORMATION
|
|
|
|
- if ( x.size() > 0 )
|
|
|
|
- Gnuplot gnu ( "Mutual Information", "smooth csplines", "threshold",
|
|
|
|
- "mi", x, y );
|
|
|
|
-#endif
|
|
|
|
-
|
|
|
|
return opt_mi;
|
|
return opt_mi;
|
|
}
|
|
}
|
|
|
|
|
|
@@ -204,14 +199,15 @@ void MutualInformation::computeThresholdsClass ( const LabeledSetVector & ls, si
|
|
{
|
|
{
|
|
size_t max_dimension = ls.dimension();
|
|
size_t max_dimension = ls.dimension();
|
|
|
|
|
|
- thresholds.clear();
|
|
|
|
- mis.clear();
|
|
|
|
|
|
+ thresholds.resize(max_dimension);
|
|
|
|
+ mis.resize(max_dimension);
|
|
|
|
+
|
|
for ( size_t k = 0 ; k < max_dimension ; k++ )
|
|
for ( size_t k = 0 ; k < max_dimension ; k++ )
|
|
{
|
|
{
|
|
double t, mi;
|
|
double t, mi;
|
|
mi = computeThresholdClass ( ls, classno, k, t );
|
|
mi = computeThresholdClass ( ls, classno, k, t );
|
|
- mis.append(mi);
|
|
|
|
- thresholds.append(t);
|
|
|
|
|
|
+ mis[k] = mi;
|
|
|
|
+ thresholds[k] = t;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -219,13 +215,15 @@ void MutualInformation::computeThresholdsOverall ( const LabeledSetVector & ls,
|
|
{
|
|
{
|
|
size_t max_dimension = ls.dimension();
|
|
size_t max_dimension = ls.dimension();
|
|
|
|
|
|
- thresholds.clear();
|
|
|
|
- mis.clear();
|
|
|
|
|
|
+ thresholds.resize(max_dimension);
|
|
|
|
+ mis.resize(max_dimension);
|
|
for ( size_t k = 0 ; k < max_dimension ; k++ )
|
|
for ( size_t k = 0 ; k < max_dimension ; k++ )
|
|
{
|
|
{
|
|
|
|
+ if ( verbose )
|
|
|
|
+ cerr << "MutualInformation: Optimizing threshold for feature " << k << " / " << max_dimension << endl;
|
|
double t, mi;
|
|
double t, mi;
|
|
mi = computeThresholdOverall ( ls, k, t );
|
|
mi = computeThresholdOverall ( ls, k, t );
|
|
- mis.append(mi);
|
|
|
|
- thresholds.append(t);
|
|
|
|
|
|
+ mis[k] = mi;
|
|
|
|
+ thresholds[k] = t;
|
|
}
|
|
}
|
|
}
|
|
}
|