main.cpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. //
  2. // Created by wrede on 19.04.16.
  3. //
  #include <algorithm>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <vector>

#include <boost/program_options.hpp>

#include "../core/DetectionSequence.h"
#include "../util/FileIO.h"
#include "../util/Parser.h"
#include "../algo/NStage.h"
#include "../algo/KShortestPaths.h"
#include "../util/Visualizer.h"
#include "../util/Logger.h"
#include "../core/ObjectDataAngular.h"
#include "../algo/Berclaz.h"
  // Settings of the n-stage algorithm. Run() binds these fields to the
  // --max-frame-skip, --max-tracklet-count and --penalty-value options;
  // RunNStage() reads them.
struct NStageParameters
{
    // Maximum number of frames a track may skip between two detections.
    size_t max_frame_skip;
    // Comma separated list of tracklet counts, split and parsed in RunNStage.
    std::string max_tracklet_count;
    // Comma separated list of penalty values, split and parsed in RunNStage.
    std::string penalty_value;
};

NStageParameters n_stage_params;
  20. void RunNStage(core::DetectionSequence& sequence,
  21. const std::string& output_file,
  22. const std::string& images_folder,
  23. bool display)
  24. {
  25. util::Logger::LogInfo("Running n-stage");
  26. std::vector<double> penalty_values;
  27. std::vector<size_t> max_tracklet_counts;
  28. // Parse strings to vectors
  29. size_t d_index;
  30. std::string str, part;
  31. str = n_stage_params.max_tracklet_count;
  32. do
  33. {
  34. d_index = str.find(",");
  35. part = str.substr(0, d_index);
  36. if (part.size() > 0)
  37. {
  38. max_tracklet_counts.push_back((unsigned long&&) std::atoi(part.c_str()));
  39. }
  40. str = str.substr(d_index + 1);
  41. }
  42. while (d_index != std::string::npos);
  43. str = n_stage_params.penalty_value;
  44. do
  45. {
  46. d_index = str.find(",");
  47. part = str.substr(0, d_index);
  48. if (part.size() > 0)
  49. {
  50. penalty_values.push_back(std::atof(part.c_str()));
  51. }
  52. str = str.substr(d_index + 1);
  53. }
  54. while (d_index != std::string::npos);
  55. // Init n stage
  56. algo::NStage n_stage(n_stage_params.max_frame_skip,
  57. penalty_values, max_tracklet_counts);
  58. std::vector<core::TrackletPtr> tracks;
  59. n_stage.Run(sequence, tracks);
  60. // Interpolate tracks
  61. for (auto track : tracks)
  62. {
  63. track->InterpolateMissingFrames();
  64. }
  65. // Display the tracking data
  66. if (display)
  67. {
  68. util::Visualizer vis;
  69. vis.Display(tracks, images_folder);
  70. }
  71. util::Logger::LogInfo("Finished");
  72. }
  // Settings of the Berclaz algorithm. Run() binds these fields to the
  // --horizontal-resolution, --vertical-resolution, --vicinity-size and
  // --max-track-count options; RunBerclaz() reads them.
struct BerclazParameters
{
    // Number of horizontal grid cells.
    int h_res;
    // Number of vertical grid cells.
    int v_res;
    // Number of cells a detection can travel between two frames.
    int vicinity_size;
    // Maximal number of tracks to extract.
    size_t max_track_count;
};

BerclazParameters berclaz_params;
  80. void RunBerclaz(core::DetectionSequence& sequence,
  81. const std::string& output_file,
  82. const std::string& images_folder,
  83. bool display)
  84. {
  85. util::Logger::LogInfo("Running berclaz");
  86. // Init berclaz
  87. algo::Berclaz berclaz(berclaz_params.h_res,
  88. berclaz_params.v_res,
  89. berclaz_params.vicinity_size);
  90. std::vector<core::TrackletPtr> tracks;
  91. berclaz.Run(sequence, berclaz_params.max_track_count, tracks);
  92. util::Logger::LogInfo("Interpolate tracks");
  93. // Interpolate tracks
  94. for (auto track : tracks)
  95. {
  96. track->InterpolateMissingFrames();
  97. }
  98. // Display the tracking data
  99. if (display)
  100. {
  101. util::Visualizer vis;
  102. vis.Display(tracks, images_folder);
  103. }
  104. util::Logger::LogInfo("Finished");
  105. }
  106. void Run(int argc, char** argv)
  107. {
  108. // Algorithm independent values
  109. std::string input_file, output_file, images_folder, algorithm, config_path;
  110. bool info, debug, display;
  111. // Input dependent variables
  112. std::string header;
  113. double temporal_weight, spatial_weight, angular_weight;
  114. double image_width, image_height;
  115. boost::program_options::options_description opts("Allowed options");
  116. opts.add_options()
  117. ("help",
  118. "produce help message")
  119. ("info",
  120. boost::program_options::value<bool>(&info)
  121. ->default_value(false),
  122. "if the program should show progress information")
  123. ("debug",
  124. boost::program_options::value<bool>(&debug)
  125. ->default_value(false),
  126. "if the program should show debug messages")
  127. ("display",
  128. boost::program_options::value<bool>(&display)
  129. ->default_value(false),
  130. "if a window with the images and the detected tracks should be opened")
  131. ("config",
  132. boost::program_options::value<std::string>(&config_path),
  133. "the path to the config file, if no path is given the command line arguments are read")
  134. ("input-file,i",
  135. boost::program_options::value<std::string>(&input_file),
  136. "set detections file path")
  137. ("output-file,o",
  138. boost::program_options::value<std::string>(&output_file),
  139. "set the output file path")
  140. ("images-folder,f",
  141. boost::program_options::value<std::string>(&images_folder),
  142. "set images folder path")
  143. ("input-header",
  144. boost::program_options::value<std::string>(&header)
  145. ->default_value(""),
  146. "sets the input header, this value is optional if the input file has a header labeling the values,"
  147. "the delimiter used for the header needs to be the same as for the rest of the file")
  148. ("algorithm,a",
  149. boost::program_options::value<std::string>(&algorithm),
  150. "set the algorithm to use, current viable options: n-stage berclaz")
  151. ("max-frame-skip",
  152. boost::program_options::value<size_t>(&n_stage_params.max_frame_skip)
  153. ->default_value(1),
  154. "(n stage) set the maximum number of frames a track can skip between two detections,"
  155. " if set to less or equal than zero all frames are linked")
  156. ("max-tracklet-count",
  157. boost::program_options::value<std::string>(&n_stage_params.max_tracklet_count)
  158. ->default_value("-1,1"),
  159. "(n stage) set the maximum number of tracklets to be extracted")
  160. ("penalty-value",
  161. boost::program_options::value<std::string>(&n_stage_params.penalty_value)
  162. ->default_value("0,0"),
  163. "(n stage) set the penalty value for edges from and to source and sink")
  164. ("temporal-weight",
  165. boost::program_options::value<double>(&temporal_weight)
  166. ->default_value(1.0),
  167. "(n stage) temporal weight for difference calculations between two detections")
  168. ("spatial-weight",
  169. boost::program_options::value<double>(&spatial_weight)
  170. ->default_value(1.0),
  171. "(n stage) spatial weight for difference calculations between two detections")
  172. ("angular-weight",
  173. boost::program_options::value<double>(&angular_weight)
  174. ->default_value(1.0),
  175. "(n stage) angular weight for difference calculations between two detections")
  176. ("horizontal-resolution",
  177. boost::program_options::value<int>(&berclaz_params.h_res)
  178. ->default_value(10),
  179. "(berclaz) the number of horizontal grid cells")
  180. ("vertical-resolution",
  181. boost::program_options::value<int>(&berclaz_params.v_res)
  182. ->default_value(10),
  183. "(berclaz) the number of vertical grid cells")
  184. ("vicinity-size",
  185. boost::program_options::value<int>(&berclaz_params.vicinity_size)
  186. ->default_value(1),
  187. "(berclaz) the vicinity size, the number of cells a detection can travel between two frames")
  188. ("max-track-count",
  189. boost::program_options::value<size_t>(&berclaz_params.max_track_count)
  190. ->default_value(1),
  191. "(berclaz) the maximal number of tracks to extract")
  192. ("image-width",
  193. boost::program_options::value<double>(&image_width)
  194. ->default_value(1920),
  195. "the width of the image")
  196. ("image-height",
  197. boost::program_options::value<double>(&image_height)
  198. ->default_value(1080),
  199. "the height of the image");
  200. boost::program_options::variables_map opt_var_map;
  201. #pragma clang diagnostic push
  202. #pragma clang diagnostic ignored "-Wincompatible-pointer-types-discards-qualifiers"
  203. boost::program_options::store(
  204. boost::program_options::parse_command_line(argc, argv, opts),
  205. opt_var_map);
  206. #pragma clang diagnostic pop
  207. boost::program_options::notify(opt_var_map);
  208. // Display help
  209. if (opt_var_map.count("help") != 0)
  210. {
  211. std::cout << opts << std::endl;
  212. exit(0);
  213. }
  214. // Read config
  215. if (opt_var_map.count("config") != 0)
  216. {
  217. std::ifstream config_file(config_path , std::ifstream::in);
  218. boost::program_options::store(
  219. boost::program_options::parse_config_file(config_file , opts),
  220. opt_var_map);
  221. config_file.close();
  222. boost::program_options::notify(opt_var_map);
  223. }
  224. // Enable info logging
  225. if (info != 0)
  226. {
  227. util::Logger::SetInfo(true);
  228. util::Logger::LogInfo("Enabled");
  229. }
  230. // Enable debug logging
  231. if (debug != 0)
  232. {
  233. util::Logger::SetDebug(true);
  234. util::Logger::LogDebug("Enabled");
  235. }
  236. // Reading the input file
  237. util::Logger::LogInfo("Reading input");
  238. util::ValueMapVector values;
  239. if (header.size() > 0)
  240. {
  241. util::FileIO::ReadCSV(values, header, input_file);
  242. }
  243. else
  244. {
  245. util::FileIO::ReadCSV(values, input_file);
  246. }
  247. // Parsing the read input
  248. core::DetectionSequence sequence;
  249. util::Parser::ParseObjectDataBox(values, sequence,
  250. image_width, image_height,
  251. temporal_weight, spatial_weight);
  252. // Running the specified algorithm
  253. if (algorithm == "n-stage")
  254. {
  255. RunNStage(sequence, output_file, images_folder, display);
  256. }
  257. else if (algorithm == "berclaz")
  258. {
  259. RunBerclaz(sequence, output_file, images_folder, display);
  260. }
  261. else
  262. {
  263. std::cout << opts << std::endl;
  264. exit(0);
  265. }
  266. }
  267. void CreateTestGraph(DirectedGraph& graph, Vertex& source, Vertex& sink)
  268. {
  269. // Create test graph (suurballe wikipedia example)
  270. // std::vector<Vertex> vertices;
  271. // for (size_t i = 0; i < 6; ++i)
  272. // {
  273. // vertices.push_back(
  274. // boost::add_vertex(
  275. // core::ObjectDataPtr(new core::ObjectData(i)),graph));
  276. // }
  277. //
  278. // // AB
  279. // boost::add_edge(vertices[0], vertices[1], 1.0, graph);
  280. //
  281. // // AC
  282. // boost::add_edge(vertices[0], vertices[2], 2.0, graph);
  283. //
  284. // // BD
  285. // boost::add_edge(vertices[1], vertices[3], 1.0, graph);
  286. //
  287. // // BE
  288. // boost::add_edge(vertices[1], vertices[4], 2.0, graph);
  289. //
  290. // // CD
  291. // boost::add_edge(vertices[2], vertices[3], 2.0, graph);
  292. //
  293. // // DF
  294. // boost::add_edge(vertices[3], vertices[5], 1.0, graph);
  295. //
  296. // // EF
  297. // boost::add_edge(vertices[4], vertices[5], 2.0, graph);
  298. //
  299. // source = vertices[0];
  300. // sink = vertices[5];
  301. // Create test graph
  302. std::vector<Vertex> vertices;
  303. for (size_t i = 0; i < 11; ++i)
  304. {
  305. vertices.push_back(
  306. boost::add_vertex(
  307. core::ObjectDataPtr(new core::ObjectData(i)),graph));
  308. }
  309. boost::add_edge(vertices[0], vertices[1], 1.0, graph);
  310. boost::add_edge(vertices[0], vertices[8], 1.0, graph);
  311. boost::add_edge(vertices[0], vertices[4], 1.0, graph);
  312. boost::add_edge(vertices[1], vertices[2], 2.0, graph);
  313. boost::add_edge(vertices[1], vertices[5], 1.0, graph);
  314. boost::add_edge(vertices[2], vertices[3], 1.0, graph);
  315. boost::add_edge(vertices[2], vertices[6], 2.0, graph);
  316. boost::add_edge(vertices[2], vertices[10], 2.0, graph);
  317. boost::add_edge(vertices[3], vertices[7], 1.0, graph);
  318. boost::add_edge(vertices[4], vertices[2], 1.0, graph);
  319. boost::add_edge(vertices[4], vertices[5], 2.0, graph);
  320. boost::add_edge(vertices[4], vertices[9], 2.0, graph);
  321. boost::add_edge(vertices[5], vertices[6], 1.0, graph);
  322. boost::add_edge(vertices[5], vertices[3], 2.0, graph);
  323. boost::add_edge(vertices[6], vertices[7], 1.0, graph);
  324. boost::add_edge(vertices[8], vertices[2], 2.0, graph);
  325. boost::add_edge(vertices[8], vertices[9], 1.0, graph);
  326. boost::add_edge(vertices[9], vertices[3], 2.0, graph);
  327. boost::add_edge(vertices[9], vertices[10], 1.0, graph);
  328. boost::add_edge(vertices[10], vertices[7], 1.0, graph);
  329. source = vertices[0];
  330. sink = vertices[7];
  331. }
  332. void TestKSP()
  333. {
  334. Vertex source, sink;
  335. DirectedGraph graph;
  336. util::Logger::SetDebug(true);
  337. util::Logger::SetInfo(true);
  338. CreateTestGraph(graph, source, sink);
  339. algo::KShortestPaths ksp(graph, source, sink);
  340. MultiPredecessorMap paths = ksp.Run(5);
  341. util::FileIO::WriteCSVMatlab(graph, "/home/wrede/Dokumente/graph.csv");
  342. util::FileIO::WriteCSVMatlab(paths,
  343. sink, "/home/wrede/Dokumente/paths.csv");
  344. }
  345. void TestGrid()
  346. {
  347. int lower_index = 0;
  348. int upper_index = 5;
  349. double lower_bound = 0.0;
  350. double upper_bound = 50.0;
  351. util::Grid grid(upper_index, upper_index, upper_index,
  352. upper_bound, upper_bound, upper_bound);
  353. std::uniform_int_distribution<int> unii(lower_index, upper_index - 1);
  354. std::uniform_real_distribution<double> unif(lower_bound, upper_bound);
  355. std::default_random_engine re;
  356. // Fill with empty values
  357. std::cout << "fill with empty values\n";
  358. for (int z = lower_index; z < upper_index; ++z)
  359. {
  360. for (int y = lower_index; y < upper_index; ++y)
  361. {
  362. for (int x = lower_index; y < upper_index; ++y)
  363. {
  364. grid.SetValue(nullptr, x, y, z);
  365. }
  366. }
  367. }
  368. // Randomly add data
  369. std::cout << "randomly add data\n";
  370. for (int i = 0; i < 10; ++i)
  371. {
  372. int xi = unii(re);
  373. int yi = unii(re);
  374. int zi = unii(re);
  375. core::ObjectDataPtr value(new core::ObjectData((size_t)i));
  376. grid.SetValue(value, xi, yi, zi);
  377. std::cout << xi << "," << yi << "," << zi << " = " << *value << std::endl;
  378. }
  379. // Randomly get data
  380. std::cout << "randomly get data\n";
  381. for (int i = 0; i < 10; ++i)
  382. {
  383. double x = unif(re);
  384. double y = unif(re);
  385. double z = unif(re);
  386. std::cout << x << "," << y << "," << z << " = ";
  387. core::ObjectDataPtr value = grid.GetValue(x, y, z);
  388. if (value)
  389. {
  390. std::cout << *value << std::endl;
  391. }
  392. else
  393. {
  394. std::cout << "nullptr" << std::endl;
  395. }
  396. }
  397. }
  398. void TestBerclazGraph()
  399. {
  400. std::cout << "init\n";
  401. // Init grid with data
  402. util::Grid grid(3, 3, 3, 9.0, 9.0, 9.0);
  403. for (int z = 0; z < grid.GetDepthCount(); ++z)
  404. {
  405. for (int y = 0; y < grid.GetHeightCount(); ++y)
  406. {
  407. for (int x = 0; x < grid.GetWidthCount(); ++x)
  408. {
  409. core::ObjectDataPtr value(new core::ObjectData(10));
  410. grid.SetValue(value, x, y, z);
  411. }
  412. }
  413. }
  414. // Add path source->0,0,0->0,0,1->0,0,2->sink
  415. core::ObjectDataPtr value0(new core::ObjectData(1));
  416. value0->SetDetectionScore(1.0);
  417. grid.SetValue(value0, 0, 0, 0);
  418. core::ObjectDataPtr value1(new core::ObjectData(2));
  419. value1->SetDetectionScore(1.0);
  420. grid.SetValue(value1, 0, 0, 1);
  421. core::ObjectDataPtr value2(new core::ObjectData(3));
  422. value2->SetDetectionScore(1.0);
  423. grid.SetValue(value2, 0, 0, 2);
  424. // Add path source->1,1,0->1,1,0->1,1,2->sink
  425. core::ObjectDataPtr value3(new core::ObjectData(4));
  426. value3->SetDetectionScore(0.6);
  427. grid.SetValue(value3, 0, 1, 0);
  428. core::ObjectDataPtr value4(new core::ObjectData(5));
  429. value4->SetDetectionScore(0.6);
  430. grid.SetValue(value4, 0, 1, 1);
  431. core::ObjectDataPtr value5(new core::ObjectData(6));
  432. value5->SetDetectionScore(0.6);
  433. grid.SetValue(value5, 0, 1, 2);
  434. // Add path source->2,2,0->2,2,0->2,2,2->sink
  435. core::ObjectDataPtr value6(new core::ObjectData(7));
  436. value6->SetDetectionScore(0.3);
  437. grid.SetValue(value6, 0, 2, 0);
  438. core::ObjectDataPtr value7(new core::ObjectData(8));
  439. value7->SetDetectionScore(0.3);
  440. grid.SetValue(value7, 0, 2, 1);
  441. core::ObjectDataPtr value8(new core::ObjectData(9));
  442. value8->SetDetectionScore(0.3);
  443. grid.SetValue(value8, 0, 2, 2);
  444. std::cout << "add vertices\n";
  445. // Add grid vertices
  446. DirectedGraph graph;
  447. for (int z = 0; z < grid.GetDepthCount(); ++z)
  448. {
  449. for (int y = 0; y < grid.GetHeightCount(); ++y)
  450. {
  451. for (int x = 0; x < grid.GetWidthCount(); ++x)
  452. {
  453. boost::add_vertex(grid.GetValue(x, y, z), graph);
  454. }
  455. }
  456. }
  457. std::cout << "vertex count = " << boost::num_vertices(graph) << std::endl;
  458. std::cout << "edge count = " << boost::num_edges(graph) << std::endl;
  459. // Add source and sink vertex
  460. Vertex source = boost::add_vertex(core::ObjectDataPtr(new core::ObjectData()), graph);
  461. Vertex sink = boost::add_vertex(core::ObjectDataPtr(new core::ObjectData()), graph);
  462. std::cout << "add edges\n";
  463. // Iterate all vertices but source and sink
  464. VertexIndexMap vertices = boost::get(boost::vertex_index, graph);
  465. VertexValueMap values = boost::get(boost::vertex_name, graph);
  466. int vicinity_size = 1;
  467. int layer_size = grid.GetWidthCount() * grid.GetHeightCount();
  468. for (int z = 0; z < grid.GetDepthCount(); ++z)
  469. {
  470. for (int y = 0; y < grid.GetHeightCount(); ++y)
  471. {
  472. for (int x = 0; x < grid.GetWidthCount(); ++x)
  473. {
  474. // First vertex index
  475. int vi = x + y * grid.GetHeightCount() + z * layer_size;
  476. // Connect with the next frame only if there is a next frame
  477. if (z < grid.GetDepthCount() - 1)
  478. {
  479. // Get the score, clamp it, prevent division by zero and
  480. // logarithm of zero
  481. double score = values[vi]->GetDetectionScore();
  482. if (score > 0.999999)
  483. {
  484. score = 0.999999;
  485. }
  486. else if (score < 0.000001)
  487. {
  488. score = 0.000001;
  489. }
  490. // Calculate the edge weight
  491. double weight = -std::log(score / (1 - score));
  492. // Iterate all nearby cells in the next frame
  493. for (int ny = std::max(0, y - vicinity_size);
  494. ny <
  495. std::min(grid.GetHeightCount(), y + vicinity_size + 1);
  496. ++ny)
  497. {
  498. for (int nx = std::max(0, x - vicinity_size);
  499. nx < std::min(grid.GetWidthCount(),
  500. x + vicinity_size + 1);
  501. ++nx)
  502. {
  503. // Second vertex index
  504. int vj = nx + ny * grid.GetHeightCount() +
  505. (z + 1) * layer_size;
  506. // Connect to nearby cells
  507. boost::add_edge(vertices[vi], vertices[vj],
  508. weight, graph);
  509. }
  510. }
  511. }
  512. // Connect with source and sink
  513. boost::add_edge(source, vertices[vi], 0.0, graph);
  514. boost::add_edge(vertices[vi], sink, 0.0, graph);
  515. }
  516. }
  517. }
  518. std::cout << "vertex count = " << boost::num_vertices(graph) << std::endl;
  519. std::cout << "edge count = " << boost::num_edges(graph) << std::endl;
  520. // Running KSP with 5 possible paths although only 3 are worth it
  521. algo::KShortestPaths ksp(graph, source, sink);
  522. MultiPredecessorMap ksp_result = ksp.Run(5);
  523. util::FileIO::WriteCSVMatlab(graph, "/home/wrede/Dokumente/graph.csv");
  524. util::FileIO::WriteCSVMatlab(ksp_result,
  525. sink, "/home/wrede/Dokumente/paths.csv");
  526. }
  527. int main(int argc, char** argv)
  528. {
  529. //Run(argc, argv);
  530. //TestTracklet();
  531. TestKSP();
  532. //TestGrid();
  533. //TestBerclazGraph();
  534. return 0;
  535. }