Эх сурвалжийг харах

added constraints to the first stage of the n-stage algorithm

Helge Wrede 9 жил өмнө
parent
commit
0d908e87b0

+ 15 - 6
algo/NStage.cpp

@@ -11,18 +11,20 @@ namespace algo
     NStage::NStage(std::vector<size_t> max_frame_skip,
                    std::vector<double> penalty_value,
                    std::vector<size_t> max_tracklet_count,
-                   double edge_weight_threshold)
+                   double edge_weight_threshold,
+                   std::unordered_map<std::string, double> constraints)
     {
         max_frame_skips_ = max_frame_skip;
         penalty_values_ = penalty_value;
         max_tracklet_counts_ = max_tracklet_count;
         iterations_ = std::min(max_tracklet_count.size(), penalty_value.size());
         edge_weight_threshold_ = edge_weight_threshold;
+        constraints_ = constraints;
     }
 
-    void NStage::CreateObjectGraph(DirectedGraph & graph, core::DetectionSequence & detections)
+    void NStage::CreateObjectGraph(DirectedGraph & graph,
+                                   core::DetectionSequence & detections)
     {
-        //TODO constraints to only create necessary
         util::Logger::LogInfo("Creating object graph");
 
         std::vector<std::vector<Vertex>> layers;
@@ -69,9 +71,16 @@ namespace algo
                     {
                         Vertex v = layers[i + k][l];
 
-                        double weight = values[u]->CompareTo(values[v]);
-                        if (weight < edge_weight_threshold_)
-                            boost::add_edge(u, v, weight, graph);
+                        // Only create the edge if the constraints are assured
+                        if (values[u]->IsWithinConstraints(values[v], constraints_))
+                        {
+                            double weight = values[u]->CompareTo(values[v]);
+
+                            if (weight < edge_weight_threshold_)
+                            {
+                                boost::add_edge(u, v, weight, graph);
+                            }
+                        }
                     }
                 }
 

+ 9 - 1
algo/NStage.h

@@ -43,11 +43,17 @@ namespace algo
          */
         double edge_weight_threshold_;
 
+        /**
+         * The constraints to ensure when creating edges
+         */
+        std::unordered_map<std::string, double> constraints_;
+
         /**
          * Creates a graph with vertices for every detected object
          *
          * @param graph The graph to write into
          * @param detections The objects to use for the graph
+         * @param constraints The constraints to assure when creating edges
          */
         void CreateObjectGraph(DirectedGraph & graph, core::DetectionSequence & detections);
 
@@ -82,11 +88,13 @@ namespace algo
          * @param max_tracklet_count The maximum number of tracklets to create
          * @param edge_weight_threshold The maximum weight an edge can have in the initial graph,
          *                              edges with higher weights are discarded
+         * @param constraints The constraints to ensure when creating edges
          */
         NStage(std::vector<size_t> max_frame_skip,
                std::vector<double> penalty_value,
                std::vector<size_t> max_tracklet_count,
-               double edge_weight_threshold);
+               double edge_weight_threshold,
+               std::unordered_map<std::string, double> constraints);
 
         /**
          * Runs the algorithm on the specified sequence and stores the found tracks into the

+ 34 - 0
core/ObjectData.cpp

@@ -8,6 +8,9 @@
 
 namespace core
 {
+    const std::string ObjectData::CONSTRAINT_FRAME_DIFFERENCE = "frame_difference";
+    const std::string ObjectData::CONSTRAINT_SCORE_DIFFERENCE = "score_difference";
+
     ObjectData::ObjectData()
             : frame_index_(0), is_virtual_(true), detection_score_(0.0)
     {
@@ -48,6 +51,37 @@ namespace core
         return 0.0;
     }
 
+    bool ObjectData::IsWithinConstraints(ObjectDataPtr obj,
+                                         std::unordered_map<std::string, double> & constraints)
+    const
+    {
+        if (constraints.count(CONSTRAINT_FRAME_DIFFERENCE) > 0)
+        {
+            size_t frame_difference;
+
+            // check the size to prevent negative values for unsigned type
+            if (frame_index_ > obj->GetFrameIndex()) {
+                frame_difference = frame_index_ - obj->GetFrameIndex();
+            } else {
+                frame_difference = obj->GetFrameIndex() - frame_index_;
+            }
+
+            if (frame_difference > constraints[CONSTRAINT_FRAME_DIFFERENCE]) {
+                return false;
+            }
+        }
+
+        if (constraints.count(CONSTRAINT_SCORE_DIFFERENCE) > 0) {
+            double score_difference = fabs(detection_score_ - obj->GetDetectionScore());
+
+            if (score_difference > constraints[CONSTRAINT_SCORE_DIFFERENCE]) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
     ObjectDataPtr ObjectData::Interpolate(ObjectDataPtr obj, double fraction) const
     {
         size_t index = static_cast<size_t>(

+ 17 - 3
core/ObjectData.h

@@ -2,8 +2,8 @@
 // Created by wrede on 19.04.16.
 //
 
-#ifndef GBMOT_NODEDATA_H
-#define GBMOT_NODEDATA_H
+#ifndef GBMOT_OBJECTDATA_H
+#define GBMOT_OBJECTDATA_H
 
 
 #include <string>
@@ -24,6 +24,9 @@ namespace core
     class ObjectData
     {
     private:
+        static const std::string CONSTRAINT_FRAME_DIFFERENCE;
+        static const std::string CONSTRAINT_SCORE_DIFFERENCE;
+
         /**
          * If this node is considered virtual
          */
@@ -94,6 +97,17 @@ namespace core
          */
         virtual double CompareTo(ObjectDataPtr obj) const;
 
+        /**
+         * Checks if the difference between this object and the specified object is within the
+         * constraints specified. The difference is calculated for each constraint separately.
+         *
+         * @param obj The object to get the difference to
+         * @param constraints The constraints to assure
+         */
+        virtual bool IsWithinConstraints(ObjectDataPtr obj,
+                                         std::unordered_map<std::string, double> & constraints)
+                const;
+
         /**
          * Linearly interpolates between this and the given object.
          * Creates a new object to fit between the two objects.
@@ -134,4 +148,4 @@ namespace core
 }
 
 
-#endif //GBMOT_NODEDATA_H
+#endif //GBMOT_OBJECTDATA_H

+ 21 - 0
core/ObjectData2D.cpp

@@ -7,6 +7,8 @@
 
 namespace core
 {
+    const std::string ObjectData2D::CONSTRAINT_DISTANCE_EUCLID = "distance_euclid";
+
     ObjectData2D::ObjectData2D(size_t frame_index, cv::Point2d position)
             : ObjectData(frame_index),
               position_(position),
@@ -50,6 +52,25 @@ namespace core
         return d_temp * temporal_weight_ + d_spat * spatial_weight_;
     }
 
+    bool ObjectData2D::IsWithinConstraints(ObjectDataPtr obj,
+                                           std::unordered_map<std::string, double> & constraints)
+    const
+    {
+        if (!ObjectData::IsWithinConstraints(obj, constraints))
+            return false;
+
+        ObjectData2DPtr obj_2d = std::static_pointer_cast<ObjectData2D>(obj);
+
+        if (constraints.count(CONSTRAINT_DISTANCE_EUCLID) > 0) {
+            double distance_euclid = util::MyMath::EuclideanDistance(position_, obj_2d->position_);
+
+            if (distance_euclid > constraints[CONSTRAINT_DISTANCE_EUCLID])
+                return false;
+        }
+
+        return true;
+    }
+
     ObjectDataPtr ObjectData2D::Interpolate(ObjectDataPtr obj,
                                             double fraction) const
     {

+ 5 - 0
core/ObjectData2D.h

@@ -18,6 +18,8 @@ namespace core
     class ObjectData2D : public ObjectData
     {
     private:
+        static const std::string CONSTRAINT_DISTANCE_EUCLID;
+
         /**
          * The position in the two dimensional space
          */
@@ -75,6 +77,9 @@ namespace core
         double GetSpatialWeight() const;
 
         virtual double CompareTo(ObjectDataPtr obj) const override;
+        virtual bool IsWithinConstraints(ObjectDataPtr obj,
+                                         std::unordered_map<std::string, double> & constraints)
+                const override;
         virtual ObjectDataPtr Interpolate(ObjectDataPtr obj, double fraction) const override;
         virtual void Visualize(cv::Mat& image, cv::Scalar& color) const override;
         virtual std::string ToString(char delimiter) const override;

+ 22 - 0
core/ObjectDataAngular.cpp

@@ -8,6 +8,8 @@
 
 namespace core
 {
+    const std::string ObjectDataAngular::CONSTRAINT_ANGULAR_DIFFERENCE = "angular_difference";
+
     ObjectDataAngular::ObjectDataAngular(size_t frame_index,
                                          const cv::Point2d& position,
                                          double angle)
@@ -55,6 +57,26 @@ namespace core
         return ObjectData2D::CompareTo(obj) + d_ang * angular_weight_;
     }
 
+    bool ObjectDataAngular::IsWithinConstraints(ObjectDataPtr obj,
+                                                std::unordered_map<std::string, double> & constraints) const
+    {
+        if (!ObjectData2D::IsWithinConstraints(obj, constraints))
+            return false;
+
+        ObjectDataAngularPtr obj_ang =
+                std::static_pointer_cast<ObjectDataAngular>(obj);
+
+        if (constraints.count(CONSTRAINT_ANGULAR_DIFFERENCE) > 0) {
+            double angular_difference = fabs(angle_ - obj_ang->angle_);
+
+            if (angular_difference > constraints[CONSTRAINT_ANGULAR_DIFFERENCE]) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
     ObjectDataPtr ObjectDataAngular::Interpolate(ObjectDataPtr obj,
                                                  double fraction) const
     {

+ 5 - 0
core/ObjectDataAngular.h

@@ -19,6 +19,8 @@ namespace core
     class ObjectDataAngular : public ObjectData2D
     {
     private:
+        static const std::string CONSTRAINT_ANGULAR_DIFFERENCE;
+
         /**
          * The rotation angle in radians
          */
@@ -73,6 +75,9 @@ namespace core
         double GetAngularWeight() const;
 
         virtual double CompareTo(ObjectDataPtr obj) const override;
+        virtual bool IsWithinConstraints(ObjectDataPtr obj,
+                                         std::unordered_map<std::string, double> & constraints)
+                const override;
         virtual ObjectDataPtr Interpolate(ObjectDataPtr obj, double fraction) const override;
         virtual void Visualize(cv::Mat& image, cv::Scalar& color) const override;
         virtual std::string ToString(char delimiter) const override;

+ 29 - 0
core/ObjectDataBox.cpp

@@ -8,6 +8,9 @@
 
 namespace core
 {
+    const std::string ObjectDataBox::CONSTRAINT_WIDTH_DIFFERENCE = "width_difference";
+    const std::string ObjectDataBox::CONSTRAINT_HEIGHT_DIFFERENCE = "height_difference";
+
     ObjectDataBox::ObjectDataBox(size_t frame_index, cv::Point2d center,
                                  cv::Point2d size)
             : ObjectData2D(frame_index, center),
@@ -38,6 +41,32 @@ namespace core
         return d_temp * GetTemporalWeight() + d_spat * GetSpatialWeight();
     }
 
+    bool ObjectDataBox::IsWithinConstraints(ObjectDataPtr obj,
+                                            std::unordered_map<std::string, double> & constraints)
+    const
+    {
+        if (!ObjectData2D::IsWithinConstraints(obj, constraints))
+            return false;
+
+        ObjectDataBoxPtr obj_box = std::static_pointer_cast<ObjectDataBox>(obj);
+
+        if (constraints.count(CONSTRAINT_WIDTH_DIFFERENCE) > 0) {
+            double width_difference = fabs(size_.x - obj_box->size_.x);
+
+            if (width_difference > constraints[CONSTRAINT_WIDTH_DIFFERENCE])
+                return false;
+        }
+
+        if (constraints.count(CONSTRAINT_HEIGHT_DIFFERENCE) > 0) {
+            double height_difference = fabs(size_.y - obj_box->size_.y);
+
+            if (height_difference > constraints[CONSTRAINT_HEIGHT_DIFFERENCE])
+                return false;
+        }
+
+        return true;
+    }
+
     ObjectDataPtr ObjectDataBox::Interpolate(ObjectDataPtr obj,
                                              double fraction) const
     {

+ 6 - 0
core/ObjectDataBox.h

@@ -18,6 +18,9 @@ namespace core
     class ObjectDataBox : public ObjectData2D
     {
     private:
+        static const std::string CONSTRAINT_WIDTH_DIFFERENCE;
+        static const std::string CONSTRAINT_HEIGHT_DIFFERENCE;
+
         /**
          * The size of the bounding box
          */
@@ -42,6 +45,9 @@ namespace core
         cv::Point2d GetSize() const;
 
         virtual double CompareTo(ObjectDataPtr obj) const override;
+        virtual bool IsWithinConstraints(ObjectDataPtr obj,
+                                         std::unordered_map<std::string, double> & constraints)
+                const override;
         virtual ObjectDataPtr Interpolate(ObjectDataPtr obj, double fraction) const override;
         virtual void Visualize(cv::Mat& image, cv::Scalar& color) const override;
         virtual std::string ToString(char delimiter) const override;

+ 17 - 2
main/main.cpp

@@ -21,6 +21,7 @@ struct
     std::string max_tracklet_count;
     std::string penalty_value;
     double edge_weight_threshold;
+    std::unordered_map<std::string, double> constraints;
 } n_stage_params;
 
 void RunNStage(core::DetectionSequence& sequence, std::vector<core::TrackletPtr>& tracks)
@@ -82,7 +83,7 @@ void RunNStage(core::DetectionSequence& sequence, std::vector<core::TrackletPtr>
 
     // Init n-stage
     algo::NStage n_stage(max_frame_skips, penalty_values, max_tracklet_counts,
-                         n_stage_params.edge_weight_threshold);
+                         n_stage_params.edge_weight_threshold, n_stage_params.constraints);
 
     n_stage.Run(sequence, tracks);
 
@@ -132,9 +133,11 @@ void RunBerclaz(core::DetectionSequence & sequence, std::vector<core::TrackletPt
 
 void Run(int argc, char const * const * argv)
 {
+    //TODO output info for all possible constraints
+
     // Algorithm independent values
     std::string input_file, output_path, images_folder, algorithm, config_path, header;
-    std::string input_format, berclaz_filter;
+    std::string input_format, berclaz_filter, n_stage_constraints;
     bool info, debug, display, output, output_images, show_grid;
     char input_delimiter, output_delimiter;
     double temporal_weight, spatial_weight, angular_weight, image_width, image_height;
@@ -237,6 +240,12 @@ void Run(int argc, char const * const * argv)
                      ->default_value(1.0),
              "(n-stage) the maximum weight an edge can have in the initial graph, edges with"
                      "higher edge weights are discarded")
+            ("n-stage.constraints",
+             boost::program_options::value<std::string>(&n_stage_constraints)
+                     ->default_value(""),
+             "(n-stage) the constraints to ensure when creating the object graph,"
+                     " values and keys are separated by commas,"
+                     " for example: key0,value0,key1,value1")
             ("berclaz.h-res",
              boost::program_options::value<int>(&berclaz_params.h_res)
                      ->default_value(10),
@@ -386,6 +395,12 @@ void Run(int argc, char const * const * argv)
     begin_time = time(0);
     if (algorithm == "n-stage")
     {
+        // Parse the constraints
+        std::vector<std::string> pairs = util::FileIO::Split(n_stage_constraints, ',');
+        for (size_t i = 0; i < pairs.size(); ++i) {
+            n_stage_params.constraints[pairs[i]] = stod(pairs[i + 1]);
+        }
+
         //TODO set the output file name
 
         RunNStage(sequence, tracks);

+ 2 - 1
util/Parser.cpp

@@ -12,14 +12,15 @@ namespace util
 {
     //TODO sigmoid function for detection score smoothing
 
-    //TODO rename if necessary
     const std::string Parser::KEY_FRAME = "frame";
     const std::string Parser::KEY_ID = "id";
     const std::string Parser::KEY_SCORE = "score";
     const std::string Parser::KEY_X = "xc";
     const std::string Parser::KEY_Y = "yc";
+    const std::string Parser::KEY_Z = "zc";
     const std::string Parser::KEY_WIDTH = "width";
     const std::string Parser::KEY_HEIGHT = "height";
+    const std::string Parser::KEY_DEPTH = "depth";
     const std::string Parser::KEY_ANGLE = "angle";
 
     void Parser::ParseObjectData2D(ValueMapVector& values,