Bjoern Froehlich il y a 13 ans
Parent
commit
0ac719bcf0
3 fichiers modifiés avec 45 ajouts et 45 suppressions
  1. 35 39
      semseg/SemSegContextTree.cpp
  2. 10 0
      semseg/SemSegContextTree.h
  3. 0 6
      semseg/SemSegCsurka.cpp

+ 35 - 39
semseg/SemSegContextTree.cpp

@@ -94,12 +94,13 @@ public:
 SemSegContextTree::SemSegContextTree( const Config *conf, const MultiDataset *md )
 SemSegContextTree::SemSegContextTree( const Config *conf, const MultiDataset *md )
     : SemanticSegmentation ( conf, &(md->getClassNames("train")) )
     : SemanticSegmentation ( conf, &(md->getClassNames("train")) )
 {
 {
+	this->conf = conf;
 	string section = "SSContextTree";
 	string section = "SSContextTree";
 	lfcw = new LFColorWeijer(conf);
 	lfcw = new LFColorWeijer(conf);
 	
 	
 	grid = conf->gI(section, "grid", 10 );
 	grid = conf->gI(section, "grid", 10 );
 	
 	
-	maxSamples = conf->gI(section, "max_samples", 2000 );
+	maxSamples = conf->gI(section, "max_samples", 200);
 	
 	
 	minFeats = conf->gI(section, "min_feats", 50 );
 	minFeats = conf->gI(section, "min_feats", 50 );
 	
 	
@@ -116,6 +117,8 @@ SemSegContextTree::SemSegContextTree( const Config *conf, const MultiDataset *md
 	ops.push_back(new Addition());
 	ops.push_back(new Addition());
 	ops.push_back(new Only1());
 	ops.push_back(new Only1());
 	
 	
+	classnames = md->getClassNames ( "train" );
+	
 	///////////////////////////////////
 	///////////////////////////////////
 	// Train Segmentation Context Trees
 	// Train Segmentation Context Trees
 	//////////////////////////////////
 	//////////////////////////////////
@@ -173,7 +176,10 @@ void SemSegContextTree::getBestSplit(const vector<vector<vector<vector<double> >
 	vector<double> fraction(a.size(),0.0);
 	vector<double> fraction(a.size(),0.0);
 	for(uint i = 0; i < fraction.size(); i++)
 	for(uint i = 0; i < fraction.size(); i++)
 	{
 	{
-		fraction[i] = ((double)maxSamples)/((double)featcounter*a[i]*a.size());
+		if ( forbidden_classes.find ( labelmapback[i] ) != forbidden_classes.end() )
+			fraction[i] = 0;
+		else
+			fraction[i] = ((double)maxSamples)/((double)featcounter*a[i]*a.size());
 		//cout << "fraction["<<i<<"]: "<< fraction[i] << " a[" << i << "]: " << a[i] << endl;
 		//cout << "fraction["<<i<<"]: "<< fraction[i] << " a[" << i << "]: " << a[i] << endl;
 	}
 	}
 	//cout << "a.size(): " << a.size() << endl;
 	//cout << "a.size(): " << a.size() << endl;
@@ -358,6 +364,16 @@ void SemSegContextTree::train ( const MultiDataset *md )
 	vector<vector<vector<int> > > currentfeats;
 	vector<vector<vector<int> > > currentfeats;
 	vector<vector<vector<int> > > labels;
 	vector<vector<vector<int> > > labels;
 	
 	
+	forbidden_classes;
+
+	std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );
+	if ( forbidden_classes_s == "" )
+	{
+		forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
+	}
+	
+	classnames.getSelection ( forbidden_classes_s, forbidden_classes );
+	
 	int imgcounter = 0;
 	int imgcounter = 0;
 	
 	
 	LOOP_ALL_S ( *trainp )
 	LOOP_ALL_S ( *trainp )
@@ -426,9 +442,10 @@ void SemSegContextTree::train ( const MultiDataset *md )
 			{
 			{
 				classno = pixelLabels.getPixel(x, y);
 				classno = pixelLabels.getPixel(x, y);
 				labels[imgcounter][x][y] = classno;
 				labels[imgcounter][x][y] = classno;
+				if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
+					continue;
 				labelcounter[classno]++;
 				labelcounter[classno]++;
-				//if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
-					//continue;
+				
 			}
 			}
 		}
 		}
 		imgcounter++;
 		imgcounter++;
@@ -437,42 +454,12 @@ void SemSegContextTree::train ( const MultiDataset *md )
 	}
 	}
 	pb.hide();
 	pb.hide();
 	
 	
-	/*int opsize = (int)ops.size();
-	int featdim = (int)allfeats[0][0][0].size();
-	
-
-	for(int x1 = -windowSize/2; x1 < windowSize/2+1; x1++)
-	{
-		for(int y1 = -windowSize/2; y1 < windowSize/2+1; y1++)
-		{
-			for(int x2 = -windowSize/2; x2 < windowSize/2+1; x2++)
-			{
-				for(int y2 = -windowSize/2; y2 < windowSize/2+1; y2++)
-				{
-					for(int f = 0; f < featdim; f++)
-					{
-						for(int o = 0; o < opsize; o++)
-						{
-							vector<int> tmp(6,0);
-							tmp[0] = x1;
-							tmp[1] = y1;
-							tmp[2] = x2;
-							tmp[3] = y2;
-							tmp[4] = f;
-							tmp[5] = o;
-							featsel.push_back(tmp);
-						}
-					}
-				}
-			}
-		}
-	}*/
-	
 	map<int,int>::iterator mapit;
 	map<int,int>::iterator mapit;
 	int classes = 0;
 	int classes = 0;
 	for(mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++)
 	for(mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++)
 	{
 	{
 		labelmap[mapit->first] = classes;
 		labelmap[mapit->first] = classes;
+		
 		labelmapback[classes] = mapit->first;
 		labelmapback[classes] = mapit->first;
 		classes++;
 		classes++;
 	}
 	}
@@ -505,6 +492,7 @@ void SemSegContextTree::train ( const MultiDataset *md )
 	{
 	{
 		cout << "a["<<i<<"]: " << a[i] << endl;
 		cout << "a["<<i<<"]: " << a[i] << endl;
 	}
 	}
+	cout << "a.size: " << a.size() << endl;
 #endif
 #endif
 	
 	
 	tree.push_back(Node());
 	tree.push_back(Node());
@@ -577,10 +565,18 @@ void SemSegContextTree::train ( const MultiDataset *md )
 					double lcounter = 0.0, rcounter = 0.0;
 					double lcounter = 0.0, rcounter = 0.0;
 					for(uint d = 0; d < tree[left].dist.size(); d++)
 					for(uint d = 0; d < tree[left].dist.size(); d++)
 					{
 					{
-						tree[left].dist[d]/=a[d];
-						lcounter +=tree[left].dist[d];
-						tree[right].dist[d]/=a[d];
-						rcounter +=tree[right].dist[d];
+						if ( forbidden_classes.find ( labelmapback[d] ) != forbidden_classes.end() )
+						{
+							tree[left].dist[d] = 0;
+							tree[right].dist[d] = 0;
+						}
+						else
+						{
+							tree[left].dist[d]/=a[d];
+							lcounter +=tree[left].dist[d];
+							tree[right].dist[d]/=a[d];
+							rcounter +=tree[right].dist[d];
+						}
 					}
 					}
 					if(lcounter <= 0 || rcounter <= 0)
 					if(lcounter <= 0 || rcounter <= 0)
 					{
 					{

+ 10 - 0
semseg/SemSegContextTree.h

@@ -122,6 +122,16 @@ class SemSegContextTree : public SemanticSegmentation
 	/** use alternative calculation for information gain */
 	/** use alternative calculation for information gain */
 	bool useShannonEntropy;
 	bool useShannonEntropy;
 	
 	
+	/** Classnames */
+	ClassNames classnames;
+
+	/** train selection */
+	set<int> forbidden_classes;
+	
+	/** Configfile */
+	const Config *conf;
+
+	
     public:
     public:
 	/** simple constructor */
 	/** simple constructor */
 	SemSegContextTree( const Config *conf, const MultiDataset *md );
 	SemSegContextTree( const Config *conf, const MultiDataset *md );

+ 0 - 6
semseg/SemSegCsurka.cpp

@@ -1,9 +1,3 @@
-/**
- * @file SemSegCsurka.cpp
- * @brief semantic segmentation using the method from Csurka08
- * @author Björn Fröhlich
- * @date 04/24/2009
- */
 #include <iostream>
 #include <iostream>
 
 
 #include "SemSegCsurka.h"
 #include "SemSegCsurka.h"