SCGiniIndex.cpp 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. /**
  2. * @file SCGiniIndex.cpp
  3. * @brief the Gini index splitting criterion
  4. * @author Sven Sickert
  5. * @date 01/16/2017
  6. */
#include "SCGiniIndex.h"

#include <vector>

using namespace OBJREC;
  9. /* default constructor */
  10. SCGiniIndex::SCGiniIndex()
  11. : SplittingCriterion ()
  12. {
  13. count_left = 0.0;
  14. count_right = 0.0;
  15. gini_left = 0.0;
  16. gini_right = 0.0;
  17. }
  18. /* simple constructor */
  19. SCGiniIndex::SCGiniIndex( int _min_examples )
  20. : SplittingCriterion ( _min_examples )
  21. {
  22. count_left = 0.0;
  23. count_right = 0.0;
  24. gini_left = 0.0;
  25. gini_right = 0.0;
  26. }
  27. /* config constructor */
  28. SCGiniIndex::SCGiniIndex( const NICE::Config *conf )
  29. : SplittingCriterion ( conf )
  30. {
  31. count_left = 0.0;
  32. count_right = 0.0;
  33. gini_left = 0.0;
  34. gini_right = 0.0;
  35. }
  36. /* copy constructor */
  37. SCGiniIndex::SCGiniIndex( const SCGiniIndex &obj )
  38. {
  39. min_examples = obj.min_examples;
  40. min_entropy = obj.min_entropy;
  41. min_purity = obj.min_purity;
  42. entropy_cur = obj.entropy_cur;
  43. count_left = obj.count_left;
  44. count_right = obj.count_right;
  45. gini_left = obj.gini_left;
  46. gini_right = obj.gini_right;
  47. }
  48. /* simple destructor */
  49. SCGiniIndex::~SCGiniIndex()
  50. {
  51. }
  52. /* cloning function */
  53. SplittingCriterion* SCGiniIndex::clone()
  54. {
  55. SplittingCriterion* sc = new SCGiniIndex( *this );
  56. return sc;
  57. }
  58. double SCGiniIndex::computeGiniIndex(
  59. const double* distribution,
  60. const double count,
  61. const int maxClassNo )
  62. {
  63. double g_sum = 0.0;
  64. for ( int j = 0 ; j <= maxClassNo ; j++ )
  65. {
  66. double p = distribution[j] / count;
  67. g_sum += p*p;
  68. }
  69. return (1-g_sum);
  70. }
  71. bool SCGiniIndex::evaluateSplit(
  72. const FeatureValuesUnsorted & values,
  73. double threshold,
  74. double* distribution_left,
  75. double* distribution_right,
  76. int maxClassNo )
  77. {
  78. this->count_left = 0;
  79. this->count_right = 0;
  80. int count_unweighted_left = 0;
  81. int count_unweighted_right = 0;
  82. double *distribution = new double [maxClassNo+1];
  83. for ( int c = 0; c <= maxClassNo; c++ )
  84. distribution[c] = 0.0;
  85. for ( FeatureValuesUnsorted::const_iterator i = values.begin();
  86. i != values.end();
  87. i++ )
  88. {
  89. int classno = i->second;
  90. double value = i->first;
  91. double weight = i->fourth;
  92. distribution[classno] += weight;
  93. if ( value < threshold ) {
  94. distribution_left[classno] += weight;
  95. this->count_left += weight;
  96. count_unweighted_left++;
  97. }
  98. else
  99. {
  100. distribution_right[classno] += weight;
  101. this->count_right += weight;
  102. count_unweighted_right++;
  103. }
  104. }
  105. if ( (count_unweighted_left < this->min_examples)
  106. || (count_unweighted_right < this->min_examples) )
  107. {
  108. delete [] distribution;
  109. return false;
  110. }
  111. // current entropy
  112. this->entropy_cur = computeEntropy( distribution, this->count_left+this->count_right, maxClassNo );
  113. // left Gini index
  114. this->gini_left = computeGiniIndex( distribution_left, this->count_left, maxClassNo );
  115. // right Gini index
  116. this->gini_right = computeGiniIndex( distribution_right, this->count_right, maxClassNo );
  117. delete [] distribution;
  118. return true;
  119. }
  120. double SCGiniIndex::computePurity() const
  121. {
  122. double p_left = (this->count_left) / (this->count_left + this->count_right);
  123. // computing Gini impurity
  124. double gi = p_left*this->gini_left + (1-p_left)*this->gini_right;
  125. return (1-gi);
  126. }