Project.cpp 16 KB


  1. /*
  2. * This file is part of the Carpe Diem Active Learning Software,
  3. * Copyright (C) 2017 Clemens-Alexander Brust (ikosa dot de at gmail dot com).
  4. *
  5. * For licensing information, see the LICENSE file included with this project.
  6. */
  #include "Project.h"

  #include <chrono>
  #include <ctime>
  #include <exception>
  #include <fstream>
  #include <iomanip>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <utility>
  #include <vector>

  #include <cn24.h>
  14. Project::Project(QObject* _parent, Conv::ClassManager::ClassUpdateHandler* class_update_handler, ProjectStateHandler* project_state_handler)
  15. : QObject(_parent), class_update_handler_(class_update_handler), state_handler_(project_state_handler), state(NOTHING)
  16. {
  17. }
  18. void Project::PredictSegment(ProjectProgressHandler *progress_handler, Conv::Segment *segment, std::vector<std::string> *prediction_filenames, std::vector<std::vector<Conv::BoundingBox> > *prediction_boxes) {
  19. if(state != Project::LOADED)
  20. return;
  21. progress_handler->OnProjectProgressUpdate(0);
  22. // Switch to testing mode
  23. graph_->SetIsTesting(true);
  24. input_layer_->ForceWeightsZero();
  25. Conv::datum total_sample_count = segment->GetSampleCount();
  26. Conv::datum running_sample_count = 0;
  27. for(unsigned int sample = 0; sample < segment->GetSampleCount(); sample++) {
  28. Conv::JSON& sample_json = segment->GetSample(sample);
  29. std::vector<Conv::BoundingBox> sample_predictions;
  30. // Load sample
  31. input_layer_->ForceLoadDetection(sample_json, 0);
  32. graph_->FeedForward();
  33. // Copy predictions
  34. Conv::DetectionMetadataPointer output_boxes = predicted_metadata_[0];
  35. for(unsigned int b = 0; b < output_boxes->size(); b++) {
  36. Conv::BoundingBox bbox = output_boxes->at(b);
  37. sample_predictions.push_back(bbox);
  38. }
  39. // Store predictions
  40. std::string sample_filename = sample_json["image_rpath"];
  41. prediction_filenames->push_back(sample_filename);
  42. prediction_boxes->push_back(sample_predictions);
  43. running_sample_count += 1.0;
  44. progress_handler->OnProjectProgressUpdate(running_sample_count / total_sample_count);
  45. }
  46. // Done
  47. progress_handler->OnProjectProgressUpdate(1);
  48. progress_handler->OnProjectProgressDone();
  49. }
  50. void Project::UpdateModel(ProjectProgressHandler *progress_handler) {
  51. if(update_set_->GetSampleCount() == 0) {
  52. progress_handler->OnProjectProgressDone();
  53. return;
  54. }
  55. needs_rescore_ = true;
  56. // Update input layer settings
  57. input_layer_->training_sets_.clear();
  58. input_layer_->training_weights_.clear();
  59. input_layer_->training_sets_.push_back(known_samples_);
  60. input_layer_->training_sets_.push_back(update_set_);
  61. input_layer_->training_weights_.push_back(1);
  62. input_layer_->training_weights_.push_back(1);
  63. input_layer_->UpdateDatasets();
  64. ProjectTrainerProgressHandler trainer_progress_handler(progress_handler);
  65. progress_handler->OnProjectProgressUpdate(0);
  66. trainer_->SetUpdateHandler(&trainer_progress_handler);
  67. trainer_->settings()["epoch_iterations"] = 10 * update_set_->GetSampleCount();
  68. trainer_->Train(1, false);
  69. progress_handler->OnProjectProgressUpdate(1);
  70. Save();
  71. progress_handler->OnProjectProgressDone();
  72. }
  73. void Project::UpdateScores(ProjectProgressHandler *progress_handler) {
  74. progress_handler->OnProjectProgressUpdate(0);
  75. if(!needs_rescore_) {
  76. progress_handler->OnProjectProgressDone();
  77. return;
  78. }
  79. Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
  80. Conv::ActiveLearningPolicy* policy = Conv::YOLOActiveLearningPolicy::CreateWithName(active_learning_policy_,netgraph_factory.GetYOLOConfiguration());
  81. Conv::NetGraphBuffer& prediction_buffer = graph_->GetOutputNodes()[0]->output_buffers[0];
  82. Conv::DatasetMetadataPointer* predicted_metadata = prediction_buffer.combined_tensor->metadata;
  83. input_layer_->ForceWeightsZero();
  84. graph_->SetIsTesting(true);
  85. Conv::datum total_sample_count = new_set_->GetSampleCount();
  86. Conv::datum running_sample_count = 0;
  87. unsigned int batch_size = prediction_buffer.combined_tensor->data.samples();
  88. for(unsigned int s = 0; s < new_set_->GetSegmentCount(); s++) {
  89. Conv::Segment* segment = new_set_->GetSegment(s);
  90. Conv::datum segment_score = 0;
  91. for(unsigned int sample = 0; sample < segment->GetSampleCount(); sample+= batch_size) {
  92. for(unsigned int bindex = 0; bindex < batch_size && (sample+bindex) < (segment->GetSampleCount()); bindex++) {
  93. Conv::JSON& sample_json = segment->GetSample(sample+bindex);
  94. input_layer_->ForceLoadDetection(sample_json, bindex);
  95. }
  96. graph_->FeedForward();
  97. for(unsigned int bindex = 0; bindex < batch_size && (sample+bindex) < (segment->GetSampleCount()); bindex++) {
  98. segment_score += policy->Score(prediction_buffer.combined_tensor->data, predicted_metadata, bindex);
  99. running_sample_count += 1.0;
  100. progress_handler->OnProjectProgressUpdate(running_sample_count / total_sample_count);
  101. }
  102. }
  103. segment->score = segment_score;
  104. }
  105. delete policy;
  106. needs_rescore_ = false;
  107. progress_handler->OnProjectProgressDone();
  108. }
  109. bool Project::Save() {
  110. bool model_result = SaveModel();
  111. if(!model_result)
  112. return false;
  113. Conv::JSON project_json = Serialize();
  114. std::string project_filename = project_folder_ + "/project.json";
  115. std::ofstream project_file(project_filename, std::ios::out);
  116. if(!project_file.good()) {
  117. LOGERROR << "Could not open " << project_filename << " for writing";
  118. return false;
  119. }
  120. project_file << project_json.dump(2);
  121. return true;
  122. }
  123. bool Project::Load(std::string project_folder) {
  124. needs_rescore_ = true;
  125. project_folder_ = project_folder;
  126. std::string project_filename = project_folder_ + "/project.json";
  127. std::ifstream project_file(project_filename, std::ios::in);
  128. if(!project_file.good()) {
  129. LOGERROR << "Could not open " << project_filename << " for reading";
  130. return false;
  131. }
  132. Conv::JSON project_json = Conv::JSON::parse(project_file);
  133. bool project_result = Deserialize(project_json);
  134. if(!project_result) {
  135. SetState(FAILED);
  136. return false;
  137. }
  138. SetState(Project::LOADED);
  139. // Load Model
  140. bool model_result = LoadModel();
  141. if(!model_result) {
  142. SetState(FAILED);
  143. return false;
  144. }
  145. return true;
  146. }
  147. bool Project::AddSample(std::string filename) {
  148. needs_rescore_ = true;
  149. Conv::Segment* target_segment = nullptr;
  150. // If no new samples have been added, create first segment
  151. if(new_set_->GetSegmentCount() == 0) {
  152. std::stringstream ss; ss << "New Data Batch " << 1;
  153. Conv::Segment* segment = new Conv::Segment(ss.str());
  154. new_set_->AddSegment(segment);
  155. target_segment = segment;
  156. } else {
  157. // Otherwise, get last segment and check if it has room for new sample
  158. Conv::Segment* segment = new_set_->GetSegment(new_set_->GetSegmentCount() - 1);
  159. if(segment->GetSampleCount() < new_batch_size_) {
  160. target_segment = segment;
  161. } else {
  162. // No room, create new segment
  163. std::stringstream ss; ss << "New Data Batch " << new_set_->GetSegmentCount() + 1;
  164. Conv::Segment* segment = new Conv::Segment(ss.str());
  165. new_set_->AddSegment(segment);
  166. target_segment = segment;
  167. }
  168. }
  169. Conv::JSON sample_json = Conv::JSON::object();
  170. sample_json["image_filename"] = filename;
  171. sample_json["boxes"] = Conv::JSON::array();
  172. return target_segment->AddSample(sample_json);
  173. }
  174. void Project::Predict(std::string image_filename, std::vector<Conv::BoundingBox> &predictions) {
  175. if(state != Project::LOADED)
  176. return;
  177. std::string found_path = Conv::PathFinder::FindPath(image_filename, "");
  178. if(found_path.length() > 0) {
  179. Conv::JSON sample_json = Conv::JSON::object();
  180. sample_json["image_rpath"] = found_path;
  181. sample_json["boxes"] = Conv::JSON::array();
  182. input_layer_->ForceWeightsZero();
  183. input_layer_->ForceLoadDetection(sample_json, 0);
  184. graph_->SetIsTesting(true);
  185. graph_->FeedForward();
  186. Conv::DetectionMetadataPointer output_boxes = predicted_metadata_[0];
  187. LOGINFO << "Predicted " << output_boxes->size() << " boxes.";
  188. for(unsigned int b = 0; b < output_boxes->size(); b++) {
  189. Conv::BoundingBox bbox = output_boxes->at(b);
  190. predictions.push_back(bbox);
  191. }
  192. } else {
  193. LOGERROR << "Could not find " << image_filename << "!";
  194. }
  195. }
  196. Conv::JSON Project::Serialize() {
  197. Conv::JSON project_json = Conv::JSON::object();
  198. project_json["architecture"] = architecture_json;
  199. project_json["update_set"] = update_set_->Serialize();
  200. project_json["known_set"] = known_samples_->Serialize();
  201. project_json["new_set"] = new_set_->Serialize();
  202. project_json["name"] = project_name_;
  203. project_json["new_batch_size"] = new_batch_size_;
  204. project_json["active_learning_policy"] = active_learning_policy_;
  205. return project_json;
  206. }
  207. bool Project::Deserialize(Conv::JSON& project_json) {
  208. if(state != Project::NOTHING) {
  209. LOGERROR << "Already have a project!";
  210. return false;
  211. }
  212. // Load JSON
  213. architecture_filename_="_from_json_";
  214. if(!project_json["architecture"].is_object()) {
  215. LOGERROR << "Project JSON is missing architecture";
  216. return false;
  217. }
  218. if(!project_json["update_set"].is_object() || !project_json["known_set"].is_object()) {
  219. LOGERROR << "Project JSON is missing set informations!";
  220. return false;
  221. }
  222. if(project_json.count("new_batch_size") == 1 && project_json["new_batch_size"].is_number()) {
  223. new_batch_size_ = project_json["new_batch_size"];
  224. }
  225. if(project_json.count("active_learning_policy") == 1 && project_json["active_learning_policy"].is_string()) {
  226. active_learning_policy_ = project_json["active_learning_policy"];
  227. }
  228. architecture_json = project_json["architecture"];
  229. // Create class manager
  230. class_manager_ = new Conv::ClassManager();
  231. class_manager_->RegisterClassUpdateHandler(class_update_handler_);
  232. // Load architecture
  233. Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
  234. graph_ = new Conv::NetGraph();
  235. // Create dataset input layer
  236. unsigned int batch_size_parallel = 1;
  237. if(netgraph_factory.GetHyperparameters().count("batch_size_parallel") == 1 && netgraph_factory.GetHyperparameters()["batch_size_parallel"].is_number()) {
  238. batch_size_parallel = netgraph_factory.GetHyperparameters()["batch_size_parallel"];
  239. }
  240. input_layer_ = new Conv::SegmentSetInputLayer(netgraph_factory.GetDataInput(), Conv::DETECTION, class_manager_, batch_size_parallel, 123923);
  241. Conv::NetGraphNode* input_node = new Conv::NetGraphNode(input_layer_);
  242. input_node->is_input = true;
  243. // Add other layers
  244. graph_->AddNode(input_node);
  245. bool result = netgraph_factory.AddLayers(*graph_, class_manager_, 23923);
  246. if(!result) {
  247. SetState(Project::FAILED);
  248. LOGERROR << "Could not construct network!";
  249. return false;
  250. }
  251. graph_->Initialize();
  252. graph_->InitializeWeights(true);
  253. // Set helper pointers
  254. predicted_metadata_ = (Conv::DetectionMetadataPointer*) graph_->GetOutputNodes()[0]->output_buffers[0].combined_tensor->metadata;
  255. // Set project properties
  256. project_name_ = project_json["name"];
  257. // Load trainer
  258. trainer_ = new Conv::Trainer(*graph_, netgraph_factory.GetHyperparameters());
  259. // Load samples
  260. known_samples_ = new Conv::SegmentSet("Known Examples");
  261. bool deserialization_result = known_samples_->Deserialize(project_json["known_set"]);
  262. update_set_ = new Conv::SegmentSet("Update Set");
  263. deserialization_result &= update_set_->Deserialize(project_json["update_set"]);
  264. new_set_ = new Conv::SegmentSet("New Set");
  265. deserialization_result &= new_set_->Deserialize(project_json["new_set"]);
  266. if(!deserialization_result) {
  267. LOGERROR << "SegmentSet deserialization failed! See log for details.";
  268. return false;
  269. }
  270. return true;
  271. }
  272. bool Project::New(std::string architecture_filename, std::string model_filename, std::string project_name, std::string project_folder) {
  273. needs_rescore_ = true;
  274. if(state != Project::NOTHING) {
  275. LOGERROR << "Already have a project!";
  276. return false;
  277. } else {
  278. // Validate filenames
  279. std::ifstream architecture_file(architecture_filename, std::ios::in);
  280. if(!architecture_file.good()) {
  281. LOGERROR << "Failed to open architecture!";
  282. return false;
  283. }
  284. std::ifstream model_file(model_filename, std::ios::in | std::ios::binary);
  285. if(!model_file.good()) {
  286. LOGERROR << "Failed to open model!";
  287. }
  288. // Create class manager
  289. class_manager_ = new Conv::ClassManager();
  290. class_manager_->RegisterClassUpdateHandler(class_update_handler_);
  291. // Load architecture
  292. architecture_json = Conv::JSON::parse(architecture_file);
  293. Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
  294. graph_ = new Conv::NetGraph();
  295. // Create dataset input layer
  296. unsigned int batch_size_parallel = 1;
  297. if(netgraph_factory.GetHyperparameters().count("batch_size_parallel") == 1 && netgraph_factory.GetHyperparameters()["batch_size_parallel"].is_number()) {
  298. batch_size_parallel = netgraph_factory.GetHyperparameters()["batch_size_parallel"];
  299. }
  300. input_layer_ = new Conv::SegmentSetInputLayer(netgraph_factory.GetDataInput(), Conv::DETECTION, class_manager_, batch_size_parallel, 123923);
  301. Conv::NetGraphNode* input_node = new Conv::NetGraphNode(input_layer_);
  302. input_node->is_input = true;
  303. // Add other layers
  304. graph_->AddNode(input_node);
  305. bool result = netgraph_factory.AddLayers(*graph_, class_manager_, 23923);
  306. if(!result) {
  307. SetState(Project::FAILED);
  308. LOGERROR << "Could not construct network!";
  309. return false;
  310. }
  311. graph_->Initialize();
  312. graph_->InitializeWeights(true);
  313. // Load model
  314. graph_->DeserializeParameters(model_file);
  315. // Load trainer
  316. trainer_ = new Conv::Trainer(*graph_, netgraph_factory.GetHyperparameters());
  317. // Initialize segment sets
  318. known_samples_ = new Conv::SegmentSet("Known Examples");
  319. update_set_ = new Conv::SegmentSet("Update Set");
  320. new_set_ = new Conv::SegmentSet("New Set");
  321. // Set helper pointers
  322. predicted_metadata_ = (Conv::DetectionMetadataPointer*) graph_->GetOutputNodes()[0]->output_buffers[0].combined_tensor->metadata;
  323. // Set project properties
  324. project_name_ = project_name;
  325. project_folder_ = project_folder;
  326. SetState(Project::LOADED);
  327. return true;
  328. }
  329. }
  330. void Project::SaveSnapshot() {
  331. if(state != LOADED)
  332. return;
  333. std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
  334. std::time_t now_c = std::chrono::system_clock::to_time_t(now);
  335. std::stringstream ss;
  336. ss << project_folder_ << "/" << "snapshot-" << std::put_time(std::localtime(&now_c), "%y%m%d-%H%M") << "-" << known_samples_->GetSampleCount() << ".CNParamX";
  337. std::string model_filename = ss.str();
  338. std::ofstream model_file(model_filename, std::ios::out | std::ios::binary);
  339. if(!model_file.good()) {
  340. LOGERROR << "Could not open " << model_filename << " for serializing the model!";
  341. }
  342. graph_->SerializeParameters(model_file);
  343. }
  344. bool Project::LoadSnapshot(std::string &file_name) {
  345. if(state != LOADED)
  346. return false;
  347. needs_rescore_ = true;
  348. std::ifstream model_file(file_name, std::ios::in | std::ios::binary);
  349. if(!model_file.good()) {
  350. LOGERROR << "Could not open " << file_name << " for deserializing the model!";
  351. return false;
  352. }
  353. graph_->DeserializeParameters(model_file);
  354. return true;
  355. }
  356. bool Project::SaveModel() {
  357. std::string model_filename = project_folder_ + "/model.CNParamX";
  358. std::ofstream model_file(model_filename, std::ios::out | std::ios::binary);
  359. if(!model_file.good()) {
  360. LOGERROR << "Could not open " << model_filename << " for serializing the model!";
  361. return false;
  362. }
  363. graph_->SerializeParameters(model_file);
  364. return true;
  365. }
  366. bool Project::LoadModel() {
  367. needs_rescore_ = true;
  368. std::string model_filename = project_folder_ + "/model.CNParamX";
  369. std::ifstream model_file(model_filename, std::ios::in | std::ios::binary);
  370. if(!model_file.good()) {
  371. LOGERROR << "Could not open " << model_filename << " for deserializing the model!";
  372. return false;
  373. }
  374. graph_->DeserializeParameters(model_file);
  375. return true;
  376. }
  377. bool Project::RenameClass(const std::string &org_name, const std::string new_name) {
  378. // Rename classes
  379. bool class_result = class_manager_->RenameClass(org_name, new_name);
  380. if(!class_result) {
  381. LOGERROR << "Could not rename class!";
  382. return false;
  383. }
  384. // Rename samples
  385. bool sample_result = known_samples_->RenameClass(org_name, new_name);
  386. sample_result &= new_set_->RenameClass(org_name, new_name);
  387. sample_result &= update_set_->RenameClass(org_name, new_name);
  388. if(!sample_result) {
  389. LOGERROR << "Could not rename class in samples!";
  390. SetState(FAILED);
  391. return false;
  392. }
  393. return true;
  394. }