NStage.h 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. //
  2. // Created by wrede on 25.04.16.
  3. //
  4. #ifndef GBMOT_NSTAGE_H
  5. #define GBMOT_NSTAGE_H
  #include <cstddef>
  #include <vector>

  #include "../core/DetectionSequence.h"
  #include "../core/Tracklet.h"
  #include "../graph/Definitions.h"
  9. namespace algo
  10. {
  11. /**
  12. * Implementation of the two-staged graph-based multi-object tracker.
  13. * Extended to allow N stages.
  14. */
  15. class NStage
  16. {
  17. private:
  18. /**
  19. * Maximum edge length to link object
  20. */
  21. std::vector<size_t> max_frame_skips_;
  22. /**
  23. * Edge value to link to source and sink
  24. */
  25. std::vector<double> penalty_values_;
  26. /**
  27. * Maximum dijkstra iterations / number of tracklets to create
  28. */
  29. std::vector<size_t> max_tracklet_counts_;
  30. /**
  31. * Number of iterations
  32. */
  33. size_t iterations_;
  34. /**
  35. * The maximum edge weights the edge can have to be created in the initial graph building
  36. */
  37. double edge_weight_threshold_;
  38. /**
  39. * Creates a graph with vertices for every detected object
  40. *
  41. * @param graph The graph to write into
  42. * @param detections The objects to use for the graph
  43. */
  44. void CreateObjectGraph(DirectedGraph & graph, core::DetectionSequence & detections);
  45. /**
  46. * Reduces the object graph into linked tracklets.
  47. *
  48. * @param obj_graph The object graph to reduce
  49. * @param tlt_graph The graph to write the tracklets in
  50. * @param frame_count The frame count of the object graph
  51. * @param iteration The current iteration
  52. */
  53. void CreateTrackletGraph(DirectedGraph & obj_graph, DirectedGraph & tlt_graph,
  54. size_t frame_count, size_t iteration);
  55. /**
  56. * Extracts the finished tracks from the given tracklet graph.
  57. *
  58. * @param tlt_graph The tracklet graph to extract from
  59. * @param depth The depth to flatten the tracklets to
  60. * @param tracks The vector to write the extracted tracks in
  61. */
  62. void ExtractTracks(DirectedGraph& tlt_graph, size_t depth,
  63. std::vector<core::TrackletPtr> & tracks);
  64. public:
  65. /**
  66. * Initializes the algorithm wih the given values.
  67. * The number of stages is determined by the size of the given
  68. * vectors.
  69. *
  70. * @param max_frame_skip The maximum edge length to link objects
  71. * @param penalty_value The edge value to link to source and sink
  72. * @param max_tracklet_count The maximum number of tracklets to create
  73. * @param edge_weight_threshold The maximum weight an edge can have in the initial graph,
  74. * edges with higher weights are discarded
  75. */
  76. NStage(std::vector<size_t> max_frame_skip,
  77. std::vector<double> penalty_value,
  78. std::vector<size_t> max_tracklet_count,
  79. double edge_weight_threshold);
  80. /**
  81. * Runs the algorithm on the specified sequence and stores the found tracks into the
  82. * specified tracks vector.
  83. *
  84. * @param sequence The detection values to use
  85. * @param tracks A vector to store the found tracks in
  86. */
  87. void Run(core::DetectionSequence & sequence,
  88. std::vector<core::TrackletPtr> & tracks);
  89. };
  90. }
  91. #endif //GBMOT_NSTAGE_H