瀏覽代碼

modified testSemSegObliqueTrees for the use of 3d images

Sven Sickert 9 年之前
父節點
當前提交
067c5cf547
共有 1 個文件被更改,包括 47 次插入17 次删除
  1. 47 17
      progs/testSemSegObliqueTrees.cpp

+ 47 - 17
progs/testSemSegObliqueTrees.cpp

@@ -66,6 +66,14 @@ int main ( int argc, char **argv )
     NICE::Timer timer;
     std::cout << "\nCLASSIFICATION" << std::endl;
     std::cout << "##############\n" << std::endl;
+
+    std::vector<int> zsizeVec;
+    bool run3Dseg = semseg->isMode3D();
+    SemSegTools::getDepthVector ( testFiles, zsizeVec, run3Dseg );
+    int depthCount, idx = 0;
+    std::vector<std::string> filelist;
+    NICE::MultiChannelImageT<int> segresult, gt;
+
     for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
     {
         for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
@@ -73,8 +81,10 @@ int main ( int argc, char **argv )
         {
             ImageInfo & info = *(*jt);
             std::string file = info.img();
+            filelist.push_back(file);
+            depthCount++;
 
-            NICE::ImageT<int> segresult, gtruth;
+            NICE::ImageT<int> gtruth, res;
             if ( info.hasLocalizationInfo() )
             {
                 const LocalizationResult *l_gt = info.localization();
@@ -86,31 +96,51 @@ int main ( int argc, char **argv )
                 std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
                           << file << std::endl;
             }
-            segresult = gtruth;
+            segresult.addChannel(gtruth);
+            gt.addChannel(gtruth);
+
+            int depthBoundary = 0;
+            if ( run3Dseg )
+                depthBoundary = zsizeVec[idx];
+
+            if ( depthCount < depthBoundary )
+                continue;
 
             // actual testing
-            NICE::MultiChannelImageT<double> probabilities;
+            NICE::MultiChannelImage3DT<double> probabilities;
 
             timer.start();
-            semseg->semanticseg( file, segresult, probabilities );
+            semseg->semanticseg( filelist, segresult, probabilities );
             timer.stop();
             std::cout << "Time for Classification: " << timer.getLastAbsolute()
                       << "\n\n";
 
             // updating confusion matrix
-            SemSegTools::updateConfusionMatrix ( segresult, gtruth, M,
-                forbiddenClasses, classMapping );
-
-            // saving results to image file
-            NICE::ColorImage rgb;
-            NICE::ColorImage rgb_gt;
-            NICE::ColorImage orig ( file );
-            classNames.labelToRGB( segresult, rgb);
-            classNames.labelToRGB( gtruth, rgb_gt);
-            std::string fname = NICE::StringTools::baseName ( file, false );
-            std::string outStr;
-            SemSegTools::saveResultsToImageFile( &conf, "analysis", orig,
-                        rgb_gt, rgb, fname, outStr );
+            res = gtruth;
+            for ( int z = 0; z < segresult.channels(); z++ )
+            {
+                for ( int y = 0; y < res.height(); y++ )
+                    for ( int x = 0; x < res.width(); x++)
+                    {
+                        res.setPixel ( x, y, segresult.get(x,y,(unsigned int)z) );
+                        if ( run3Dseg )
+                            gtruth.setPixel ( x, y, gt.get(x,y,(unsigned int)z) );
+                    }
+
+                SemSegTools::updateConfusionMatrix ( res, gtruth, M,
+                    forbiddenClasses, classMapping );
+
+                // saving results to image file
+                NICE::ColorImage rgb;
+                NICE::ColorImage rgb_gt;
+                NICE::ColorImage orig ( filelist[z] );
+                classNames.labelToRGB( res, rgb);
+                classNames.labelToRGB( gtruth, rgb_gt);
+                std::string fname = NICE::StringTools::baseName ( filelist[z], false );
+                std::string outStr;
+                SemSegTools::saveResultsToImageFile( &conf, "analysis", orig,
+                            rgb_gt, rgb, fname, outStr );
+            }
         }
     }