|
@@ -1,25 +1,25 @@
|
|
|
#include "SemSegContextTree3D.h"
|
|
|
|
|
|
-#include "core/basics/FileName.h"
|
|
|
-#include "core/basics/numerictools.h"
|
|
|
-#include "core/basics/quadruplet.h"
|
|
|
-#include "core/basics/StringTools.h"
|
|
|
-#include "core/basics/Timer.h"
|
|
|
-#include "core/basics/vectorio.h"
|
|
|
-#include "core/image/Filter.h"
|
|
|
-#include "core/image/FilterT.h"
|
|
|
-#include "core/image/Morph.h"
|
|
|
-#include "core/imagedisplay/ImageDisplay.h"
|
|
|
-
|
|
|
-#include "vislearning/baselib/cc.h"
|
|
|
-#include "vislearning/baselib/Globals.h"
|
|
|
-#include "vislearning/baselib/ICETools.h"
|
|
|
-#include "vislearning/cbaselib/CachedExample.h"
|
|
|
-#include "vislearning/cbaselib/PascalResults.h"
|
|
|
-
|
|
|
-#include "segmentation/RSGraphBased.h"
|
|
|
-#include "segmentation/RSMeanShift.h"
|
|
|
-#include "segmentation/RSSlic.h"
|
|
|
+#include <core/basics/FileName.h>
|
|
|
+#include <core/basics/numerictools.h>
|
|
|
+#include <core/basics/quadruplet.h>
|
|
|
+#include <core/basics/StringTools.h>
|
|
|
+#include <core/basics/Timer.h>
|
|
|
+#include <core/basics/vectorio.h>
|
|
|
+#include <core/image/Filter.h>
|
|
|
+#include <core/image/FilterT.h>
|
|
|
+#include <core/image/Morph.h>
|
|
|
+#include <core/imagedisplay/ImageDisplay.h>
|
|
|
+
|
|
|
+#include <vislearning/baselib/cc.h>
|
|
|
+#include <vislearning/baselib/Globals.h>
|
|
|
+#include <vislearning/baselib/ICETools.h>
|
|
|
+#include <vislearning/cbaselib/CachedExample.h>
|
|
|
+#include <vislearning/cbaselib/PascalResults.h>
|
|
|
+
|
|
|
+#include <segmentation/RSGraphBased.h>
|
|
|
+#include <segmentation/RSMeanShift.h>
|
|
|
+#include <segmentation/RSSlic.h>
|
|
|
|
|
|
#include <omp.h>
|
|
|
#include <iostream>
|
|
@@ -74,8 +74,8 @@ SemSegContextTree3D::SemSegContextTree3D () : SemanticSegmentation ()
|
|
|
|
|
|
SemSegContextTree3D::SemSegContextTree3D (
|
|
|
const Config *conf,
|
|
|
- const MultiDataset *md )
|
|
|
- : SemanticSegmentation ( conf, & ( md->getClassNames ( "train" ) ) )
|
|
|
+ const ClassNames *classNames )
|
|
|
+ : SemanticSegmentation ( conf, classNames )
|
|
|
{
|
|
|
this->conf = conf;
|
|
|
|
|
@@ -117,7 +117,7 @@ SemSegContextTree3D::SemSegContextTree3D (
|
|
|
if ( useWeijer )
|
|
|
this->lfcw = new LocalFeatureColorWeijer ( conf );
|
|
|
|
|
|
- this->classnames = md->getClassNames ( "train" );
|
|
|
+ this->classnames = (*classNames);
|
|
|
|
|
|
// feature types
|
|
|
this->useFeat0 = conf->gB ( section, "use_feat_0", true); // pixel pair features
|
|
@@ -185,6 +185,13 @@ void SemSegContextTree3D::initOperations()
|
|
|
o->setContext(true);
|
|
|
tops3.push_back ( o );
|
|
|
}
|
|
|
+ if ( conf->gB ( featsec, "bi_int", true ) )
|
|
|
+ {
|
|
|
+ tops2.push_back ( new BiIntegralOps3D() );
|
|
|
+ Operation3D* o = new BiIntegralOps3D();
|
|
|
+ o->setContext(true);
|
|
|
+ tops3.push_back ( o );
|
|
|
+ }
|
|
|
if ( conf->gB ( featsec, "bi_int_cent", true ) )
|
|
|
{
|
|
|
tops2.push_back ( new BiIntegralCenteredOps3D() );
|
|
@@ -452,11 +459,11 @@ double SemSegContextTree3D::getBestSplit (
|
|
|
z2 = ( int ) ( rand() % 8 );
|
|
|
}
|
|
|
|
|
|
-// if (conf->gB ( "SSContextTree", "z_negative_only", false ))
|
|
|
-// {
|
|
|
-// z1 = -abs(z1);
|
|
|
-// z2 = -abs(z2);
|
|
|
-// }
|
|
|
+ if (conf->gB ( "SSContextTree", "z_negative_only", false ))
|
|
|
+ {
|
|
|
+ z1 = abs(z1);
|
|
|
+ z2 = abs(z2);
|
|
|
+ }
|
|
|
|
|
|
/* random feature maps (channels) */
|
|
|
int f1, f2;
|
|
@@ -479,6 +486,7 @@ double SemSegContextTree3D::getBestSplit (
|
|
|
|
|
|
Operation3D *op = ops[ft][o]->clone();
|
|
|
op->set ( x1, y1, z1, x2, y2, z2, f1, f2, ft );
|
|
|
+ op->setWSize(windowSize);
|
|
|
|
|
|
if ( ft == 3 || ft == 4 )
|
|
|
op->setContext ( true );
|
|
@@ -761,9 +769,8 @@ inline double computeWeight ( const int &d, const int &dim )
|
|
|
|
|
|
void SemSegContextTree3D::train ( const MultiDataset *md )
|
|
|
{
|
|
|
- const LabeledSet trainSet = * ( *md ) ["train"];
|
|
|
- const LabeledSet *trainp = &trainSet;
|
|
|
-
|
|
|
+ const LabeledSet *trainp = ( *md ) ["train"];
|
|
|
+
|
|
|
if ( saveLoadData )
|
|
|
{
|
|
|
if ( FileMgt::fileExists ( fileLocation ) )
|