// NStage.h (3.0 KB)
  1. //
  2. // Created by wrede on 25.04.16.
  3. //
  4. #ifndef GBMOT_NSTAGE_H
  5. #define GBMOT_NSTAGE_H
  6. #include "../core/DetectionSequence.h"
  7. #include "../core/Tracklet.h"
  8. #include "../graph/Definitions.h"
  9. namespace algo
  10. {
  11. /**
  12. * Implementation of the two-staged graph-based multi-object tracker.
  13. * Extended to allow N stages.
  14. */
  15. class NStage
  16. {
  17. private:
  18. /**
  19. * Maximum edge length to link object
  20. */
  21. std::vector<size_t> max_frame_skips_;
  22. /**
  23. * Edge value to link to source and sink
  24. */
  25. std::vector<double> penalty_values_;
  26. /**
  27. * Maximum dijkstra iterations / number of tracklets to create
  28. */
  29. std::vector<size_t> max_tracklet_counts_;
  30. /**
  31. * Number of iterations
  32. */
  33. size_t iterations_;
  34. /**
  35. * Creates a graph with vertices for every detected object
  36. * @param graph The graph to write into
  37. * @param detections The objects to use for the graph
  38. */
  39. void CreateObjectGraph(DirectedGraph& graph, const core::DetectionSequence& detections);
  40. /**
  41. * Reduces the object graph into linked tracklets.
  42. * @param obj_graph The object graph to reduce
  43. * @param tlt_graph The graph to write the tracklets in
  44. * @param frame_count The frame count of the object graph
  45. * @param iteration The current iteration
  46. */
  47. void CreateTrackletGraph(DirectedGraph& obj_graph, DirectedGraph& tlt_graph,
  48. size_t frame_count, size_t iteration);
  49. /**
  50. * Extracts the finished tracks from the given tracklet graph.
  51. * @param tlt_graph The tracklet graph to extract from
  52. * @param depth The depth to flatten the tracklets to
  53. * @param tracks The vector to write the extracted tracks in
  54. */
  55. void ExtractTracks(DirectedGraph& tlt_graph, size_t depth,
  56. std::vector<core::TrackletPtr>& tracks);
  57. public:
  58. /**
  59. * Initializes the algorithm wih the given values.
  60. * The number of stages is determined by the size of the given
  61. * vectors.
  62. * @param max_frame_skip The maximum edge length to link objects
  63. * @param penalty_value The edge value to link to source and sink
  64. * @param max_tracklet_count The maximum number of tracklets to create
  65. */
  66. NStage(std::vector<size_t> max_frame_skip,
  67. std::vector<double> penalty_value,
  68. std::vector<size_t> max_tracklet_count);
  69. /**
  70. * Runs the algorithm on the specified sequence and stores the found tracks into the
  71. * specified tracks vector.
  72. *
  73. * @param sequence The detection values to use
  74. * @param tracks A vector to store the found tracks in
  75. */
  76. void Run(const core::DetectionSequence& sequence,
  77. std::vector<core::TrackletPtr>& tracks);
  78. };
  79. }
  80. #endif //GBMOT_NSTAGE_H