|
@@ -94,12 +94,13 @@ public:
|
|
SemSegContextTree::SemSegContextTree( const Config *conf, const MultiDataset *md )
|
|
SemSegContextTree::SemSegContextTree( const Config *conf, const MultiDataset *md )
|
|
: SemanticSegmentation ( conf, &(md->getClassNames("train")) )
|
|
: SemanticSegmentation ( conf, &(md->getClassNames("train")) )
|
|
{
|
|
{
|
|
|
|
+ this->conf = conf;
|
|
string section = "SSContextTree";
|
|
string section = "SSContextTree";
|
|
lfcw = new LFColorWeijer(conf);
|
|
lfcw = new LFColorWeijer(conf);
|
|
|
|
|
|
grid = conf->gI(section, "grid", 10 );
|
|
grid = conf->gI(section, "grid", 10 );
|
|
|
|
|
|
- maxSamples = conf->gI(section, "max_samples", 2000 );
|
|
|
|
|
|
+ maxSamples = conf->gI(section, "max_samples", 200);
|
|
|
|
|
|
minFeats = conf->gI(section, "min_feats", 50 );
|
|
minFeats = conf->gI(section, "min_feats", 50 );
|
|
|
|
|
|
@@ -116,6 +117,8 @@ SemSegContextTree::SemSegContextTree( const Config *conf, const MultiDataset *md
|
|
ops.push_back(new Addition());
|
|
ops.push_back(new Addition());
|
|
ops.push_back(new Only1());
|
|
ops.push_back(new Only1());
|
|
|
|
|
|
|
|
+ classnames = md->getClassNames ( "train" );
|
|
|
|
+
|
|
///////////////////////////////////
|
|
///////////////////////////////////
|
|
// Train Segmentation Context Trees
|
|
// Train Segmentation Context Trees
|
|
//////////////////////////////////
|
|
//////////////////////////////////
|
|
@@ -173,7 +176,10 @@ void SemSegContextTree::getBestSplit(const vector<vector<vector<vector<double> >
|
|
vector<double> fraction(a.size(),0.0);
|
|
vector<double> fraction(a.size(),0.0);
|
|
for(uint i = 0; i < fraction.size(); i++)
|
|
for(uint i = 0; i < fraction.size(); i++)
|
|
{
|
|
{
|
|
- fraction[i] = ((double)maxSamples)/((double)featcounter*a[i]*a.size());
|
|
|
|
|
|
+ if ( forbidden_classes.find ( labelmapback[i] ) != forbidden_classes.end() )
|
|
|
|
+ fraction[i] = 0;
|
|
|
|
+ else
|
|
|
|
+ fraction[i] = ((double)maxSamples)/((double)featcounter*a[i]*a.size());
|
|
//cout << "fraction["<<i<<"]: "<< fraction[i] << " a[" << i << "]: " << a[i] << endl;
|
|
//cout << "fraction["<<i<<"]: "<< fraction[i] << " a[" << i << "]: " << a[i] << endl;
|
|
}
|
|
}
|
|
//cout << "a.size(): " << a.size() << endl;
|
|
//cout << "a.size(): " << a.size() << endl;
|
|
@@ -358,6 +364,16 @@ void SemSegContextTree::train ( const MultiDataset *md )
|
|
vector<vector<vector<int> > > currentfeats;
|
|
vector<vector<vector<int> > > currentfeats;
|
|
vector<vector<vector<int> > > labels;
|
|
vector<vector<vector<int> > > labels;
|
|
|
|
|
|
|
|
+ forbidden_classes;
|
|
|
|
+
|
|
|
|
+ std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );
|
|
|
|
+ if ( forbidden_classes_s == "" )
|
|
|
|
+ {
|
|
|
|
+ forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ classnames.getSelection ( forbidden_classes_s, forbidden_classes );
|
|
|
|
+
|
|
int imgcounter = 0;
|
|
int imgcounter = 0;
|
|
|
|
|
|
LOOP_ALL_S ( *trainp )
|
|
LOOP_ALL_S ( *trainp )
|
|
@@ -426,9 +442,10 @@ void SemSegContextTree::train ( const MultiDataset *md )
|
|
{
|
|
{
|
|
classno = pixelLabels.getPixel(x, y);
|
|
classno = pixelLabels.getPixel(x, y);
|
|
labels[imgcounter][x][y] = classno;
|
|
labels[imgcounter][x][y] = classno;
|
|
|
|
+ if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
|
|
|
|
+ continue;
|
|
labelcounter[classno]++;
|
|
labelcounter[classno]++;
|
|
- //if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
|
|
|
|
- //continue;
|
|
|
|
|
|
+
|
|
}
|
|
}
|
|
}
|
|
}
|
|
imgcounter++;
|
|
imgcounter++;
|
|
@@ -437,42 +454,12 @@ void SemSegContextTree::train ( const MultiDataset *md )
|
|
}
|
|
}
|
|
pb.hide();
|
|
pb.hide();
|
|
|
|
|
|
- /*int opsize = (int)ops.size();
|
|
|
|
- int featdim = (int)allfeats[0][0][0].size();
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- for(int x1 = -windowSize/2; x1 < windowSize/2+1; x1++)
|
|
|
|
- {
|
|
|
|
- for(int y1 = -windowSize/2; y1 < windowSize/2+1; y1++)
|
|
|
|
- {
|
|
|
|
- for(int x2 = -windowSize/2; x2 < windowSize/2+1; x2++)
|
|
|
|
- {
|
|
|
|
- for(int y2 = -windowSize/2; y2 < windowSize/2+1; y2++)
|
|
|
|
- {
|
|
|
|
- for(int f = 0; f < featdim; f++)
|
|
|
|
- {
|
|
|
|
- for(int o = 0; o < opsize; o++)
|
|
|
|
- {
|
|
|
|
- vector<int> tmp(6,0);
|
|
|
|
- tmp[0] = x1;
|
|
|
|
- tmp[1] = y1;
|
|
|
|
- tmp[2] = x2;
|
|
|
|
- tmp[3] = y2;
|
|
|
|
- tmp[4] = f;
|
|
|
|
- tmp[5] = o;
|
|
|
|
- featsel.push_back(tmp);
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- }*/
|
|
|
|
-
|
|
|
|
map<int,int>::iterator mapit;
|
|
map<int,int>::iterator mapit;
|
|
int classes = 0;
|
|
int classes = 0;
|
|
for(mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++)
|
|
for(mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++)
|
|
{
|
|
{
|
|
labelmap[mapit->first] = classes;
|
|
labelmap[mapit->first] = classes;
|
|
|
|
+
|
|
labelmapback[classes] = mapit->first;
|
|
labelmapback[classes] = mapit->first;
|
|
classes++;
|
|
classes++;
|
|
}
|
|
}
|
|
@@ -505,6 +492,7 @@ void SemSegContextTree::train ( const MultiDataset *md )
|
|
{
|
|
{
|
|
cout << "a["<<i<<"]: " << a[i] << endl;
|
|
cout << "a["<<i<<"]: " << a[i] << endl;
|
|
}
|
|
}
|
|
|
|
+ cout << "a.size: " << a.size() << endl;
|
|
#endif
|
|
#endif
|
|
|
|
|
|
tree.push_back(Node());
|
|
tree.push_back(Node());
|
|
@@ -577,10 +565,18 @@ void SemSegContextTree::train ( const MultiDataset *md )
|
|
double lcounter = 0.0, rcounter = 0.0;
|
|
double lcounter = 0.0, rcounter = 0.0;
|
|
for(uint d = 0; d < tree[left].dist.size(); d++)
|
|
for(uint d = 0; d < tree[left].dist.size(); d++)
|
|
{
|
|
{
|
|
- tree[left].dist[d]/=a[d];
|
|
|
|
- lcounter +=tree[left].dist[d];
|
|
|
|
- tree[right].dist[d]/=a[d];
|
|
|
|
- rcounter +=tree[right].dist[d];
|
|
|
|
|
|
+ if ( forbidden_classes.find ( labelmapback[d] ) != forbidden_classes.end() )
|
|
|
|
+ {
|
|
|
|
+ tree[left].dist[d] = 0;
|
|
|
|
+ tree[right].dist[d] = 0;
|
|
|
|
+ }
|
|
|
|
+ else
|
|
|
|
+ {
|
|
|
|
+ tree[left].dist[d]/=a[d];
|
|
|
|
+ lcounter +=tree[left].dist[d];
|
|
|
|
+ tree[right].dist[d]/=a[d];
|
|
|
|
+ rcounter +=tree[right].dist[d];
|
|
|
|
+ }
|
|
}
|
|
}
|
|
if(lcounter <= 0 || rcounter <= 0)
|
|
if(lcounter <= 0 || rcounter <= 0)
|
|
{
|
|
{
|