/**
 * @file SCGiniIndex.cpp
 * @brief the Gini index splitting criterion
 * @author Sven Sickert
 * @date 01/16/2017
 */

#include "SCGiniIndex.h"

#include <vector>

using namespace OBJREC;

/* default constructor */
SCGiniIndex::SCGiniIndex()
  : SplittingCriterion ()
{
  count_left = 0.0;
  count_right = 0.0;
  gini_left = 0.0;
  gini_right = 0.0;
}

/* simple constructor */
SCGiniIndex::SCGiniIndex( int _min_examples )
  : SplittingCriterion ( _min_examples )
{
  count_left = 0.0;
  count_right = 0.0;
  gini_left = 0.0;
  gini_right = 0.0;
}

/* config constructor */
SCGiniIndex::SCGiniIndex( const NICE::Config *conf )
  : SplittingCriterion ( conf )
{
  count_left = 0.0;
  count_right = 0.0;
  gini_left = 0.0;
  gini_right = 0.0;
}

/* copy constructor: copies the base-class thresholds/state that this file
 * touches plus all Gini-specific members.
 * NOTE(review): any other SplittingCriterion members are default-initialized,
 * not copied — confirm against SplittingCriterion's declaration. */
SCGiniIndex::SCGiniIndex( const SCGiniIndex &obj )
  : SplittingCriterion ()
{
  min_examples = obj.min_examples;
  min_entropy  = obj.min_entropy;
  min_purity   = obj.min_purity;
  entropy_cur  = obj.entropy_cur;

  count_left  = obj.count_left;
  count_right = obj.count_right;
  gini_left   = obj.gini_left;
  gini_right  = obj.gini_right;
}

/* simple destructor */
SCGiniIndex::~SCGiniIndex()
{
}

/* cloning function: returns a heap-allocated copy; the caller owns it */
SplittingCriterion* SCGiniIndex::clone()
{
  SplittingCriterion* sc = new SCGiniIndex( *this );
  return sc;
}

/**
 * @brief Gini impurity G = 1 - sum_j p_j^2 of a weighted class distribution,
 *        with p_j = distribution[j] / count for classes 0..maxClassNo.
 * @param distribution per-class accumulated weights (maxClassNo+1 entries)
 * @param count        total weight of the distribution
 * @param maxClassNo   highest class index stored in @p distribution
 * @return Gini impurity; 0.0 for a distribution without positive total
 *         weight (an empty partition is treated as pure instead of
 *         dividing by zero and producing NaN)
 */
double SCGiniIndex::computeGiniIndex(
    const double* distribution,
    const double count,
    const int maxClassNo )
{
  // Guard: an empty (or zero-weight) side would otherwise yield 0/0 = NaN,
  // which survives the later weighting with p_left = 0 (0 * NaN == NaN).
  if ( count <= 0.0 )
    return 0.0;

  double g_sum = 0.0;
  for ( int j = 0 ; j <= maxClassNo ; j++ )
  {
    const double p = distribution[j] / count;
    g_sum += p*p;
  }
  return (1-g_sum);
}

/**
 * @brief Evaluates a threshold split of @p values.
 *
 * Samples with value < threshold go left, all others go right. Per-class
 * weights are ADDED to @p distribution_left / @p distribution_right (they
 * are not reset here). The weighted side totals are stored in
 * count_left / count_right, the entropy of the parent node in entropy_cur,
 * and the per-side Gini impurities in gini_left / gini_right.
 *
 * @return false if either side receives fewer than min_examples samples
 *         (unweighted count), true otherwise
 */
bool SCGiniIndex::evaluateSplit(
    const FeatureValuesUnsorted & values,
    double threshold,
    double* distribution_left,
    double* distribution_right,
    int maxClassNo )
{
  this->count_left = 0;
  this->count_right = 0;
  int count_unweighted_left = 0;
  int count_unweighted_right = 0;

  // class distribution of the parent node (both sides together);
  // RAII storage so nothing leaks on early return or exception
  std::vector<double> distribution( maxClassNo+1, 0.0 );

  for ( FeatureValuesUnsorted::const_iterator i = values.begin();
        i != values.end();
        ++i )
  {
    int classno = i->second;
    double value = i->first;
    double weight = i->fourth;

    distribution[classno] += weight;

    if ( value < threshold )
    {
      distribution_left[classno] += weight;
      this->count_left += weight;
      count_unweighted_left++;
    }
    else
    {
      distribution_right[classno] += weight;
      this->count_right += weight;
      count_unweighted_right++;
    }
  }

  if ( (count_unweighted_left < this->min_examples) ||
       (count_unweighted_right < this->min_examples) )
    return false;

  // current entropy (of the unsplit node)
  double *dist = distribution.empty() ? 0 : &distribution[0];
  this->entropy_cur = computeEntropy( dist,
                                      this->count_left+this->count_right,
                                      maxClassNo );

  // left Gini index
  this->gini_left = computeGiniIndex( distribution_left,
                                      this->count_left,
                                      maxClassNo );

  // right Gini index
  this->gini_right = computeGiniIndex( distribution_right,
                                       this->count_right,
                                       maxClassNo );

  return true;
}

/**
 * @brief Purity of the last evaluated split: 1 minus the weight-averaged
 *        Gini impurity of both sides.
 * @return purity value; 0.0 when no weight was assigned to either side
 *         (instead of the NaN produced by dividing by zero)
 */
double SCGiniIndex::computePurity() const
{
  const double total = this->count_left + this->count_right;
  if ( total <= 0.0 )
    return 0.0;

  double p_left = (this->count_left) / total;

  // computing Gini impurity
  double gi = p_left*this->gini_left + (1-p_left)*this->gini_right;

  return (1-gi);
}