// Berclaz.cpp 5.3 KB
  1. //
  2. // Created by wrede on 02.06.16.
  3. //
  4. #include "Berclaz.h"
  5. #include "KShortestPaths.h"
  6. #include "../util/Parser.h"
  7. #include "../util/Logger.h"
  8. #include "../util/FileIO.h"
  9. namespace algo
  10. {
  11. Berclaz::Berclaz(int h_res, int v_res, int vicinity_size)
  12. {
  13. h_res_ = h_res;
  14. v_res_ = v_res;
  15. vicinity_size_ = vicinity_size;
  16. }
  17. void Berclaz::CreateGraph(DirectedGraph& graph, Vertex& source, Vertex& sink,
  18. util::Grid& grid)
  19. {
  20. // Add source vertex
  21. source = boost::add_vertex(core::ObjectDataPtr(new core::ObjectData()), graph);
  22. // Add grid vertices
  23. for (int f = 0; f < grid.GetDepthCount(); ++f)
  24. {
  25. for (int y = 0; y < v_res_; ++y)
  26. {
  27. for (int x = 0; x < h_res_; ++x)
  28. {
  29. boost::add_vertex(grid.GetValue(x, y, f), graph);
  30. }
  31. }
  32. }
  33. // Add sink vertex
  34. sink = boost::add_vertex(core::ObjectDataPtr(new core::ObjectData()), graph);
  35. // Store the vertex indices
  36. VertexIndexMap vertices = boost::get(boost::vertex_index, graph);
  37. // Store the vertex values
  38. VertexValueMap values = boost::get(boost::vertex_name, graph);
  39. util::Logger::LogDebug("num vertices " + std::to_string(boost::num_vertices(graph)));
  40. // Iterate all vertices but source and sink
  41. for (int f = 0; f < grid.GetDepthCount(); ++f)
  42. {
  43. for (int y = 0; y < v_res_; ++y)
  44. {
  45. for (int x = 0; x < h_res_; ++x)
  46. {
  47. // First vertex index
  48. int vi = x + y * h_res_ + f * h_res_ * v_res_ + 1;
  49. double score = values[vi]->GetDetectionScore();
  50. if (score > MAX_SCORE_VALUE)
  51. {
  52. score = MAX_SCORE_VALUE;
  53. }
  54. else if (score < MIN_SCORE_VALUE)
  55. {
  56. score = MIN_SCORE_VALUE;
  57. }
  58. // Iterate all nearby cells in the next frame
  59. for (int nx = std::max(0, x - vicinity_size_);
  60. nx < std::min(h_res_, x + vicinity_size_ + 1);
  61. ++nx)
  62. {
  63. for (int ny = std::max(0, y - vicinity_size_);
  64. ny < std::min(v_res_, y + vicinity_size_ + 1);
  65. ++ny)
  66. {
  67. // Second vertex index
  68. int vj = nx + ny * h_res_ + (f + 1) * h_res_ * v_res_ + 1;
  69. // Connect to nearby cells
  70. double weight = -std::log(score / (1 - score));
  71. boost::add_edge(vertices[vi], vertices[vj],
  72. weight, graph);
  73. }
  74. }
  75. // Connect with source and sink
  76. boost::add_edge(source, vertices[vi],
  77. VIRTUAL_EDGE_WEIGHT, graph);
  78. boost::add_edge(vertices[vi], sink,
  79. VIRTUAL_EDGE_WEIGHT, graph);
  80. }
  81. }
  82. }
  83. util::Logger::LogDebug("num edges " + std::to_string(boost::num_edges(graph)));
  84. }
  85. void Berclaz::ExtractTracks(DirectedGraph& graph,
  86. MultiPredecessorMap& map, Vertex origin,
  87. std::vector<core::TrackletPtr>& tracks)
  88. {
  89. VertexValueMap values = boost::get(boost::vertex_name, graph);
  90. // Move along all paths in reverse, starting at the origin
  91. for (Vertex first : map[origin])
  92. {
  93. core::TrackletPtr tracklet(new core::Tracklet());
  94. // The paths are node disjoint, so there should always be only one
  95. // node to proceed to
  96. for (Vertex u = first, v = (*map[u].begin());
  97. u != v; u = v, v = (*map[v].begin()))
  98. {
  99. tracklet->AddPathObject(values[u]);
  100. }
  101. tracks.push_back(tracklet);
  102. }
  103. }
  104. void Berclaz::Run(core::DetectionSequence& sequence,
  105. size_t max_track_count,
  106. std::vector<core::TrackletPtr>& tracks)
  107. {
  108. Vertex source, sink;
  109. DirectedGraph graph;
  110. util::Grid grid = util::Parser::ParseGrid(sequence,
  111. MIN_H_VALUE,
  112. MAX_H_VALUE,
  113. h_res_,
  114. MIN_V_VALUE,
  115. MAX_V_VALUE,
  116. v_res_);
  117. util::Logger::LogDebug("create graph");
  118. CreateGraph(graph, source, sink, grid);
  119. util::FileIO::WriteCSVMatlab(graph, "/home/wrede/Dokumente/graph.csv");
  120. util::Logger::LogDebug("init ksp");
  121. KShortestPaths ksp(graph, source, sink);
  122. util::Logger::LogDebug("run ksp");
  123. MultiPredecessorMap ksp_result = ksp.Run(max_track_count);
  124. util::Logger::LogDebug("extract tracks");
  125. ExtractTracks(graph, ksp_result, sink, tracks);
  126. }
  127. }