Sfoglia il codice sorgente

[feature] Class Manager: New Dialog

Clemens-Alexander Brust 9 anni fa
parent
commit
7e333df2e0
9 ha cambiato i file con 734 aggiunte e 502 eliminazioni
  1. 71 0
      src/ClassManagerDialog.cpp
  2. 36 0
      src/ClassManagerDialog.h
  3. 78 0
      src/ClassManagerDialog.ui
  4. 13 2
      src/MainWindow.cpp
  5. 2 0
      src/MainWindow.h
  6. 7 0
      src/MainWindow.ui
  7. 424 402
      src/Project.cpp
  8. 97 95
      src/Project.h
  9. 6 3
      src/carpediem.pro

+ 71 - 0
src/ClassManagerDialog.cpp

@@ -0,0 +1,71 @@
+#include "ClassManagerDialog.h"
+#include "ui_ClassManagerDialog.h"
+
+#include <sstream>
+#include <QMessageBox>
+#include <QInputDialog>
+#include <QDebug>
+
+ClassManagerDialog::ClassManagerDialog(QWidget *parent, Conv::ClassManager* class_manager, Project* project) :
+  QDialog(parent),
+  class_manager(class_manager),
+  project(project),
+  ui(new Ui::ClassManagerDialog)
+{
+  ui->setupUi(this);
+  UpdateClasses();
+}
+
+void ClassManagerDialog::UpdateClasses() {
+  ui->classList->clear();
+  for(Conv::ClassManager::const_iterator it = class_manager->begin(); it != class_manager->end(); it++) {
+    std::string class_name = it->first;
+    ui->classList->addItem(QString::fromStdString(class_name));
+  }
+
+  if(class_manager->GetClassCount() > 0) {
+    ui->renameButton->setEnabled(true);
+    ui->classList->item(0)->setSelected(true);
+    ui->classList->setCurrentRow(0);
+  } else {
+    ui->renameButton->setEnabled(false);
+  }
+}
+
+ClassManagerDialog::~ClassManagerDialog()
+{
+  delete ui;
+}
+
+void ClassManagerDialog::on_renameButton_clicked()
+{
+  if(current_class.length() > 0) {
+    std::stringstream ss;
+    ss << "Renaming the class means that you will not be able to use any model snapshots created before this date! ";
+    ss << "Do you want to rename the class \"" << current_class << "\"?";
+    QMessageBox::warning(this, tr("Warning"), QString::fromStdString(ss.str()), QMessageBox::Yes, QMessageBox::No);
+
+    // Get new class name
+    bool ok = false;
+    QString result = QInputDialog::getText(this, tr("New Class Name"), tr("New Class Name:"), QLineEdit::Normal, QString(), &ok);
+
+    if(ok && result.length() > 0) {
+      bool rename_result = project->RenameClass(current_class, result.toStdString());
+      if(!rename_result) {
+        QMessageBox::warning(this, tr("Error"), tr("Could not rename the class! Please see the log for details."));
+        reject();
+      } else {
+        UpdateClasses();
+      }
+    }
+  }
+}
+
+void ClassManagerDialog::on_classList_currentItemChanged(QListWidgetItem *current, QListWidgetItem *previous)
+{
+  if(current != nullptr) {
+    current_class = current->text().toStdString();
+  } else {
+    current_class = "";
+  }
+}

+ 36 - 0
src/ClassManagerDialog.h

@@ -0,0 +1,36 @@
+#ifndef CLASSMANAGERDIALOG_H
+#define CLASSMANAGERDIALOG_H
+
+#include <QDialog>
+#include <QListWidgetItem>
+
+#include "Project.h"
+#include <cn24.h>
+
+namespace Ui {
+class ClassManagerDialog;
+}
+
+class ClassManagerDialog : public QDialog
+{
+  Q_OBJECT
+
+public:
+  ClassManagerDialog(QWidget *parent, Conv::ClassManager* class_manager, Project* project);
+  ~ClassManagerDialog();
+
+  void UpdateClasses();
+
+private slots:
+  void on_renameButton_clicked();
+
+  void on_classList_currentItemChanged(QListWidgetItem *current, QListWidgetItem *previous);
+
+private:
+  std::string current_class;
+  Conv::ClassManager* class_manager;
+  Project* project;
+  Ui::ClassManagerDialog *ui;
+};
+
+#endif // CLASSMANAGERDIALOG_H

+ 78 - 0
src/ClassManagerDialog.ui

@@ -0,0 +1,78 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<ui version="4.0">
+ <class>ClassManagerDialog</class>
+ <widget class="QDialog" name="ClassManagerDialog">
+  <property name="geometry">
+   <rect>
+    <x>0</x>
+    <y>0</y>
+    <width>771</width>
+    <height>677</height>
+   </rect>
+  </property>
+  <property name="windowTitle">
+   <string>Manage Classes</string>
+  </property>
+  <layout class="QVBoxLayout" name="verticalLayout">
+   <item>
+    <widget class="QListWidget" name="classList"/>
+   </item>
+   <item>
+    <layout class="QHBoxLayout" name="horizontalLayout">
+     <item>
+      <widget class="QPushButton" name="renameButton">
+       <property name="text">
+        <string>Rename...</string>
+       </property>
+      </widget>
+     </item>
+     <item>
+      <widget class="QDialogButtonBox" name="buttonBox">
+       <property name="orientation">
+        <enum>Qt::Horizontal</enum>
+       </property>
+       <property name="standardButtons">
+        <set>QDialogButtonBox::Ok</set>
+       </property>
+      </widget>
+     </item>
+    </layout>
+   </item>
+  </layout>
+ </widget>
+ <resources/>
+ <connections>
+  <connection>
+   <sender>buttonBox</sender>
+   <signal>accepted()</signal>
+   <receiver>ClassManagerDialog</receiver>
+   <slot>accept()</slot>
+   <hints>
+    <hint type="sourcelabel">
+     <x>248</x>
+     <y>254</y>
+    </hint>
+    <hint type="destinationlabel">
+     <x>157</x>
+     <y>274</y>
+    </hint>
+   </hints>
+  </connection>
+  <connection>
+   <sender>buttonBox</sender>
+   <signal>rejected()</signal>
+   <receiver>ClassManagerDialog</receiver>
+   <slot>reject()</slot>
+   <hints>
+    <hint type="sourcelabel">
+     <x>316</x>
+     <y>260</y>
+    </hint>
+    <hint type="destinationlabel">
+     <x>286</x>
+     <y>274</y>
+    </hint>
+   </hints>
+  </connection>
+ </connections>
+</ui>

+ 13 - 2
src/MainWindow.cpp

@@ -24,6 +24,7 @@
 #include "SelectScoredSegmentDialog.h"
 #include "ProgressDialog.h"
 #include "LabelHypothesesDialog.h"
+#include "ClassManagerDialog.h"
 
 MainWindow::MainWindow(QWidget *parent) :
   QMainWindow(parent),
@@ -84,6 +85,7 @@ void MainWindow::OnProjectStateUpdate() {
     ui->actionNew->setEnabled(false);
     ui->actionLoad->setEnabled(false);
     ui->actionSet_Active_Learning_Policy->setEnabled(true);
+    ui->actionManage_Classes->setEnabled(true);
   } else {
     ui->predictMultipleButton->setEnabled(false);
     ui->updateModelButton->setEnabled(false);
@@ -101,6 +103,7 @@ void MainWindow::OnProjectStateUpdate() {
     ui->actionNew->setEnabled(project.state == Project::NOTHING);
     ui->actionLoad->setEnabled(project.state == Project::NOTHING);
     ui->actionSet_Active_Learning_Policy->setEnabled(false);
+    ui->actionManage_Classes->setEnabled(false);
   }
 
   // Model status
@@ -127,8 +130,10 @@ void MainWindow::OnProjectStateUpdate() {
     }
 
     ui->modelStatusText->setText(QString::fromStdString(ss.str()));
-  } else {
+  } else if(project.state == Project::NOTHING) {
     ui->modelStatusText->setText("Welcome to Carpe Diem!\nTo get started, please create a new project or load an existing project using the menu options.");
+  } else if(project.state == Project::FAILED) {
+    ui->modelStatusText->setText("Please restart Carpe Diem.");
   }
 
   // Window title
@@ -144,7 +149,7 @@ void MainWindow::OnProjectStateUpdate() {
 void MainWindow::on_actionLoad_YOLO_Small_VOC_triggered()
 {
 #ifdef W_OS_LINUX
-  bool result = project.New("/home/clemens/.cn24/yolo/yolo-small-aug.json", "/home/clemens/.cn24/yolo/yolo-small.CNParamX", "YOLO-Small", "/home/clemens/.cn24/proj");
+  bool result = project.New("/home/clemens/.cn24/yolo/base_all.json", "/home/clemens/.cn24/yolo/yolo-small.CNParamX", "YOLO-Small", "/home/clemens/.cn24/proj");
 #else
   bool result = project.New("C:/Users/clemens/cn24/base_all.json", "C:/Users/clemens/cn24/yolo-small.CNParamX", "YOLO-Small", "C:/Users/clemens/cn24/proj");
 #endif
@@ -458,3 +463,9 @@ void MainWindow::on_actionAbout_Carpe_Diem_triggered()
 {
   QMessageBox::information(this, tr("About Carpe Diem"), tr("Carpe Diem Version 1.0.0-alpha1\n\nCopyright (C) 2017 Clemens-Alexander Brust\nE-Mail:ikosa.de@gmail.com"), QMessageBox::Ok);
 }
+
+void MainWindow::on_actionManage_Classes_triggered()
+{
+  ClassManagerDialog d(this, project.class_manager_, &project);
+  d.exec();
+}

+ 2 - 0
src/MainWindow.h

@@ -70,6 +70,8 @@ private slots:
 
   void on_actionAbout_Carpe_Diem_triggered();
 
+  void on_actionManage_Classes_triggered();
+
 private:
   void on_importLabeledDataButton_clicked(); // not a slot anymore
   void AddSegmentImagesToList(Conv::Segment* segment);

+ 7 - 0
src/MainWindow.ui

@@ -115,6 +115,8 @@
     <addaction name="actionReview_Known_Data"/>
     <addaction name="actionReview_New_Data"/>
     <addaction name="actionReview_Labeled_Data"/>
+    <addaction name="separator"/>
+    <addaction name="actionManage_Classes"/>
    </widget>
    <widget class="QMenu" name="menuAdvanced">
     <property name="title">
@@ -284,6 +286,11 @@
     <string>About Carpe Diem...</string>
    </property>
   </action>
+  <action name="actionManage_Classes">
+   <property name="text">
+    <string>Manage Classes...</string>
+   </property>
+  </action>
   <zorder>dockWidget_2</zorder>
   <zorder>dockWidget_4</zorder>
  </widget>

+ 424 - 402
src/Project.cpp

@@ -5,405 +5,427 @@
  * For licensing information, see the LICENSE file included with this project.
  */
 
-#include "Project.h"
-
-#include <string>
-#include <cn24.h>
-
-Project::Project(QObject* _parent, Conv::ClassManager::ClassUpdateHandler* class_update_handler, ProjectStateHandler* project_state_handler)
-  : QObject(_parent), class_update_handler_(class_update_handler), state_handler_(project_state_handler), state(NOTHING)
-{
-
-}
-
-void Project::PredictSegment(ProjectProgressHandler *progress_handler, Conv::Segment *segment, std::vector<std::string> *prediction_filenames, std::vector<std::vector<Conv::BoundingBox> > *prediction_boxes) {
-  if(state != Project::LOADED)
-    return;
-
-  progress_handler->OnProjectProgressUpdate(0);
-
-  // Switch to testing mode
-  graph_->SetIsTesting(true);
-  input_layer_->ForceWeightsZero();
-
-  Conv::datum total_sample_count = segment->GetSampleCount();
-  Conv::datum running_sample_count = 0;
-
-  for(unsigned int sample = 0; sample < segment->GetSampleCount(); sample++) {
-    Conv::JSON& sample_json = segment->GetSample(sample);
-    std::vector<Conv::BoundingBox> sample_predictions;
-
-    // Load sample
-    input_layer_->ForceLoadDetection(sample_json, 0);
-    graph_->FeedForward();
-
-    // Copy predictions
-    Conv::DetectionMetadataPointer output_boxes = predicted_metadata_[0];
-    for(unsigned int b = 0; b < output_boxes->size(); b++) {
-      Conv::BoundingBox bbox = output_boxes->at(b);
-      sample_predictions.push_back(bbox);
-    }
-
-    // Store predictions
-    std::string sample_filename = sample_json["image_rpath"];
-    prediction_filenames->push_back(sample_filename);
-    prediction_boxes->push_back(sample_predictions);
-
-    running_sample_count += 1.0;
-    progress_handler->OnProjectProgressUpdate(running_sample_count / total_sample_count);
-  }
-
-  // Done
-  progress_handler->OnProjectProgressUpdate(1);
-  progress_handler->OnProjectProgressDone();
-}
-
-void Project::UpdateModel(ProjectProgressHandler *progress_handler) {
-  if(update_set_->GetSampleCount() == 0) {
-    progress_handler->OnProjectProgressDone();
-    return;
-  }
-  needs_rescore_ = true;
-  // Update input layer settings
-  input_layer_->training_sets_.clear();
-  input_layer_->training_weights_.clear();
-  input_layer_->training_sets_.push_back(known_samples_);
-  input_layer_->training_sets_.push_back(update_set_);
-  input_layer_->training_weights_.push_back(1);
-  input_layer_->training_weights_.push_back(1);
-  input_layer_->UpdateDatasets();
-
-  ProjectTrainerProgressHandler trainer_progress_handler(progress_handler);
-
-  progress_handler->OnProjectProgressUpdate(0);
-  trainer_->SetUpdateHandler(&trainer_progress_handler);
-  trainer_->settings()["epoch_iterations"] = 10 * update_set_->GetSampleCount();
-  trainer_->Train(1, false);
-
-  progress_handler->OnProjectProgressUpdate(1);
-  progress_handler->OnProjectProgressDone();
-}
-
-void Project::UpdateScores(ProjectProgressHandler *progress_handler) {
-  progress_handler->OnProjectProgressUpdate(0);
-  if(!needs_rescore_) {
-    progress_handler->OnProjectProgressDone();
-    return;
-  }
-
-  Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
-  Conv::ActiveLearningPolicy* policy = Conv::YOLOActiveLearningPolicy::CreateWithName(active_learning_policy_,netgraph_factory.GetYOLOConfiguration());
-
-  Conv::NetGraphBuffer& prediction_buffer = graph_->GetOutputNodes()[0]->output_buffers[0];
-  Conv::DatasetMetadataPointer* predicted_metadata = prediction_buffer.combined_tensor->metadata;
-
-  input_layer_->ForceWeightsZero();
-  graph_->SetIsTesting(true);
-
-  Conv::datum total_sample_count = new_set_->GetSampleCount();
-  Conv::datum running_sample_count = 0;
-
-  for(unsigned int s = 0; s < new_set_->GetSegmentCount(); s++) {
-    Conv::Segment* segment = new_set_->GetSegment(s);
-    Conv::datum segment_score = 0;
-    for(unsigned int sample = 0; sample < segment->GetSampleCount(); sample++) {
-      Conv::JSON& sample_json = segment->GetSample(sample);
-      input_layer_->ForceLoadDetection(sample_json, 0);
-      graph_->FeedForward();
-      segment_score += policy->Score(prediction_buffer.combined_tensor->data, predicted_metadata, 0);
-      running_sample_count += 1.0;
-      progress_handler->OnProjectProgressUpdate(running_sample_count / total_sample_count);
-    }
-    segment->score = segment_score;
-  }
-
-  delete policy;
-  needs_rescore_ = false;
-  progress_handler->OnProjectProgressDone();
-}
-
-bool Project::Save() {
-  bool model_result = SaveModel();
-  if(!model_result)
-    return false;
-
-  Conv::JSON project_json = Serialize();
-  std::string project_filename = project_folder_ + "/project.json";
-  std::ofstream project_file(project_filename, std::ios::out);
-  if(!project_file.good()) {
-    LOGERROR << "Could not open " << project_filename << " for writing";
-    return false;
-  }
-
-  project_file << project_json.dump(2);
-  return true;
-}
-
-bool Project::Load(std::string project_folder) {
-  needs_rescore_ = true;
-  project_folder_ = project_folder;
-  std::string project_filename = project_folder_ + "/project.json";
-  std::ifstream project_file(project_filename, std::ios::in);
-  if(!project_file.good()) {
-    LOGERROR << "Could not open " << project_filename << " for reading";
-    return false;
-  }
-  Conv::JSON project_json = Conv::JSON::parse(project_file);
-  bool project_result = Deserialize(project_json);
-  if(!project_result) {
-    SetState(FAILED);
-    return false;
-  }
-
-  SetState(Project::LOADED);
-
-  // Load Model
-  bool model_result = LoadModel();
-  if(!model_result) {
-    SetState(FAILED);
-    return false;
-  }
-
-  return true;
-}
-
-bool Project::AddSample(std::string filename) {
-  needs_rescore_ = true;
-  Conv::Segment* target_segment = nullptr;
-
-  // If no new samples have been added, create first segment
-  if(new_set_->GetSegmentCount() == 0) {
-    std::stringstream ss; ss << "New Data Batch " << 1;
-    Conv::Segment* segment = new Conv::Segment(ss.str());
-    new_set_->AddSegment(segment);
-    target_segment = segment;
-  } else {
-    // Otherwise, get last segment and check if it has room for new sample
-    Conv::Segment* segment = new_set_->GetSegment(new_set_->GetSegmentCount() - 1);
-    if(segment->GetSampleCount() < new_batch_size_) {
-      target_segment = segment;
-    } else {
-      // No room, create new segment
-      std::stringstream ss; ss << "New Data Batch " << new_set_->GetSegmentCount() + 1;
-      Conv::Segment* segment = new Conv::Segment(ss.str());
-      new_set_->AddSegment(segment);
-      target_segment = segment;
-    }
-  }
-
-  Conv::JSON sample_json = Conv::JSON::object();
-  sample_json["image_filename"] = filename;
-  sample_json["boxes"] = Conv::JSON::array();
-
-  return target_segment->AddSample(sample_json);
-}
-
-void Project::Predict(std::string image_filename, std::vector<Conv::BoundingBox> &predictions) {
-  if(state != Project::LOADED)
-    return;
-
-  std::string found_path = Conv::PathFinder::FindPath(image_filename, "");
-  if(found_path.length() > 0) {
-    Conv::JSON sample_json = Conv::JSON::object();
-    sample_json["image_rpath"] = found_path;
-    sample_json["boxes"] = Conv::JSON::array();
-    input_layer_->ForceLoadDetection(sample_json, 0);
-
-    graph_->SetIsTesting(true);
-    graph_->FeedForward();
-
-    Conv::DetectionMetadataPointer output_boxes = predicted_metadata_[0];
-    LOGINFO << "Predicted " << output_boxes->size() << " boxes.";
-
-    for(unsigned int b = 0; b < output_boxes->size(); b++) {
-      Conv::BoundingBox bbox = output_boxes->at(b);
-      predictions.push_back(bbox);
-    }
-  } else {
-    LOGERROR << "Could not find " << image_filename << "!";
-  }
-}
-
-Conv::JSON Project::Serialize() {
-  Conv::JSON project_json = Conv::JSON::object();
-  project_json["architecture"] = architecture_json;
-  project_json["update_set"] = update_set_->Serialize();
-  project_json["known_set"] = known_samples_->Serialize();
-  project_json["new_set"] = new_set_->Serialize();
-  project_json["name"] = project_name_;
-  project_json["new_batch_size"] = new_batch_size_;
-  project_json["active_learning_policy"] = active_learning_policy_;
-  return project_json;
-}
-
-bool Project::Deserialize(Conv::JSON& project_json) {
-  if(state != Project::NOTHING) {
-    LOGERROR << "Already have a project!";
-    return false;
-  }
-
-  // Load JSON
-  architecture_filename_="_from_json_";
-  if(!project_json["architecture"].is_object()) {
-    LOGERROR << "Project JSON is missing architecture";
-    return false;
-  }
-  if(!project_json["update_set"].is_object() || !project_json["known_set"].is_object()) {
-    LOGERROR << "Project JSON is missing set informations!";
-    return false;
-  }
-  if(project_json.count("new_batch_size") == 1 && project_json["new_batch_size"].is_number()) {
-    new_batch_size_ = project_json["new_batch_size"];
-  }
-  if(project_json.count("active_learning_policy") == 1 && project_json["active_learning_policy"].is_string()) {
-    active_learning_policy_ = project_json["active_learning_policy"];
-  }
-
-  architecture_json = project_json["architecture"];
-
-  // Create class manager
-  class_manager_ = new Conv::ClassManager();
-  class_manager_->RegisterClassUpdateHandler(class_update_handler_);
-
-  // Load architecture
-  Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
-  graph_ = new Conv::NetGraph();
-
-  // Create dataset input layer
-  unsigned int batch_size_parallel = 1;
-  if(netgraph_factory.GetHyperparameters().count("batch_size_parallel") == 1 && netgraph_factory.GetHyperparameters()["batch_size_parallel"].is_number()) {
-    batch_size_parallel = netgraph_factory.GetHyperparameters()["batch_size_parallel"];
-  }
-  input_layer_ = new Conv::SegmentSetInputLayer(netgraph_factory.GetDataInput(), Conv::DETECTION, class_manager_, batch_size_parallel, 123923);
-  Conv::NetGraphNode* input_node = new Conv::NetGraphNode(input_layer_);
-  input_node->is_input = true;
-
-  // Add other layers
-  graph_->AddNode(input_node);
-  bool result = netgraph_factory.AddLayers(*graph_, class_manager_, 23923);
-  if(!result) {
-    SetState(Project::FAILED);
-    LOGERROR << "Could not construct network!";
-    return false;
-  }
-  graph_->Initialize();
-  graph_->InitializeWeights(true);
-
-  // Set helper pointers
-  predicted_metadata_ = (Conv::DetectionMetadataPointer*) graph_->GetOutputNodes()[0]->output_buffers[0].combined_tensor->metadata;
-
-  // Set project properties
-  project_name_ = project_json["name"];
-
-  // Load trainer
-  trainer_ = new Conv::Trainer(*graph_, netgraph_factory.GetHyperparameters());
-
-  // Load samples
-  known_samples_ = new Conv::SegmentSet("Known Examples");
-  bool deserialization_result = known_samples_->Deserialize(project_json["known_set"]);
-  update_set_ = new Conv::SegmentSet("Update Set");
-  deserialization_result &= update_set_->Deserialize(project_json["update_set"]);
-  new_set_ = new Conv::SegmentSet("New Set");
-  deserialization_result &= new_set_->Deserialize(project_json["new_set"]);
-
-  if(!deserialization_result) {
-    LOGERROR << "SegmentSet deserialization failed! See log for details.";
-    return false;
-  }
-
-  return true;
-}
-
-bool Project::New(std::string architecture_filename, std::string model_filename, std::string project_name, std::string project_folder) {
-  needs_rescore_ = true;
-  if(state != Project::NOTHING) {
-    LOGERROR << "Already have a project!";
-    return false;
-  } else {
-    // Validate filenames
-    std::ifstream architecture_file(architecture_filename, std::ios::in);
-    if(!architecture_file.good()) {
-      LOGERROR << "Failed to open architecture!";
-      return false;
-    }
-
-    std::ifstream model_file(model_filename, std::ios::in | std::ios::binary);
-    if(!model_file.good()) {
-      LOGERROR << "Failed to open model!";
-    }
-
-    // Create class manager
-    class_manager_ = new Conv::ClassManager();
-    class_manager_->RegisterClassUpdateHandler(class_update_handler_);
-
-    // Load architecture
-    architecture_json = Conv::JSON::parse(architecture_file);
-    Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
-    graph_ = new Conv::NetGraph();
-
-    // Create dataset input layer
-    unsigned int batch_size_parallel = 1;
-    if(netgraph_factory.GetHyperparameters().count("batch_size_parallel") == 1 && netgraph_factory.GetHyperparameters()["batch_size_parallel"].is_number()) {
-      batch_size_parallel = netgraph_factory.GetHyperparameters()["batch_size_parallel"];
-    }
-    input_layer_ = new Conv::SegmentSetInputLayer(netgraph_factory.GetDataInput(), Conv::DETECTION, class_manager_, batch_size_parallel, 123923);
-    Conv::NetGraphNode* input_node = new Conv::NetGraphNode(input_layer_);
-    input_node->is_input = true;
-
-    // Add other layers
-    graph_->AddNode(input_node);
-    bool result = netgraph_factory.AddLayers(*graph_, class_manager_, 23923);
-    if(!result) {
-      SetState(Project::FAILED);
-      LOGERROR << "Could not construct network!";
-      return false;
-    }
-    graph_->Initialize();
-    graph_->InitializeWeights(true);
-
-
-    // Load model
-    graph_->DeserializeParameters(model_file);
-
-    // Load trainer
-    trainer_ = new Conv::Trainer(*graph_, netgraph_factory.GetHyperparameters());
-
-    // Initialize segment sets
-    known_samples_ = new Conv::SegmentSet("Known Examples");
-    update_set_ = new Conv::SegmentSet("Update Set");
-    new_set_ = new Conv::SegmentSet("New Set");
-
-    // Set helper pointers
-    predicted_metadata_ = (Conv::DetectionMetadataPointer*) graph_->GetOutputNodes()[0]->output_buffers[0].combined_tensor->metadata;
-
-    // Set project properties
-    project_name_ = project_name;
-    project_folder_ = project_folder;
-
-    SetState(Project::LOADED);
-    return true;
-  }
-}
-
-bool Project::SaveModel() {
-  std::string model_filename = project_folder_ + "/model.CNParamX";
-  std::ofstream model_file(model_filename, std::ios::out | std::ios::binary);
-  if(!model_file.good()) {
-    LOGERROR << "Could not open " << model_filename << " for serializing the model!";
-    return false;
-  }
-  graph_->SerializeParameters(model_file);
-  return true;
-}
-
-bool Project::LoadModel() {
-  needs_rescore_ = true;
-  std::string model_filename = project_folder_ + "/model.CNParamX";
-  std::ifstream model_file(model_filename, std::ios::in | std::ios::binary);
-  if(!model_file.good()) {
-    LOGERROR << "Could not open " << model_filename << " for deserializing the model!";
-    return false;
-  }
-  graph_->DeserializeParameters(model_file);
-  return true;
-}
+#include "Project.h"
+
+#include <string>
+#include <cn24.h>
+
+Project::Project(QObject* _parent, Conv::ClassManager::ClassUpdateHandler* class_update_handler, ProjectStateHandler* project_state_handler)
+  : QObject(_parent), class_update_handler_(class_update_handler), state_handler_(project_state_handler), state(NOTHING)
+{
+
+}
+
+void Project::PredictSegment(ProjectProgressHandler *progress_handler, Conv::Segment *segment, std::vector<std::string> *prediction_filenames, std::vector<std::vector<Conv::BoundingBox> > *prediction_boxes) {
+  if(state != Project::LOADED)
+    return;
+
+  progress_handler->OnProjectProgressUpdate(0);
+
+  // Switch to testing mode
+  graph_->SetIsTesting(true);
+  input_layer_->ForceWeightsZero();
+
+  Conv::datum total_sample_count = segment->GetSampleCount();
+  Conv::datum running_sample_count = 0;
+
+  for(unsigned int sample = 0; sample < segment->GetSampleCount(); sample++) {
+    Conv::JSON& sample_json = segment->GetSample(sample);
+    std::vector<Conv::BoundingBox> sample_predictions;
+
+    // Load sample
+    input_layer_->ForceLoadDetection(sample_json, 0);
+    graph_->FeedForward();
+
+    // Copy predictions
+    Conv::DetectionMetadataPointer output_boxes = predicted_metadata_[0];
+    for(unsigned int b = 0; b < output_boxes->size(); b++) {
+      Conv::BoundingBox bbox = output_boxes->at(b);
+      sample_predictions.push_back(bbox);
+    }
+
+    // Store predictions
+    std::string sample_filename = sample_json["image_rpath"];
+    prediction_filenames->push_back(sample_filename);
+    prediction_boxes->push_back(sample_predictions);
+
+    running_sample_count += 1.0;
+    progress_handler->OnProjectProgressUpdate(running_sample_count / total_sample_count);
+  }
+
+  // Done
+  progress_handler->OnProjectProgressUpdate(1);
+  progress_handler->OnProjectProgressDone();
+}
+
+void Project::UpdateModel(ProjectProgressHandler *progress_handler) {
+  if(update_set_->GetSampleCount() == 0) {
+    progress_handler->OnProjectProgressDone();
+    return;
+  }
+  needs_rescore_ = true;
+  // Update input layer settings
+  input_layer_->training_sets_.clear();
+  input_layer_->training_weights_.clear();
+  input_layer_->training_sets_.push_back(known_samples_);
+  input_layer_->training_sets_.push_back(update_set_);
+  input_layer_->training_weights_.push_back(1);
+  input_layer_->training_weights_.push_back(1);
+  input_layer_->UpdateDatasets();
+
+  ProjectTrainerProgressHandler trainer_progress_handler(progress_handler);
+
+  progress_handler->OnProjectProgressUpdate(0);
+  trainer_->SetUpdateHandler(&trainer_progress_handler);
+  trainer_->settings()["epoch_iterations"] = 10 * update_set_->GetSampleCount();
+  trainer_->Train(1, false);
+
+  progress_handler->OnProjectProgressUpdate(1);
+  progress_handler->OnProjectProgressDone();
+}
+
+void Project::UpdateScores(ProjectProgressHandler *progress_handler) {
+  progress_handler->OnProjectProgressUpdate(0);
+  if(!needs_rescore_) {
+    progress_handler->OnProjectProgressDone();
+    return;
+  }
+
+  Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
+  Conv::ActiveLearningPolicy* policy = Conv::YOLOActiveLearningPolicy::CreateWithName(active_learning_policy_,netgraph_factory.GetYOLOConfiguration());
+
+  Conv::NetGraphBuffer& prediction_buffer = graph_->GetOutputNodes()[0]->output_buffers[0];
+  Conv::DatasetMetadataPointer* predicted_metadata = prediction_buffer.combined_tensor->metadata;
+
+  input_layer_->ForceWeightsZero();
+  graph_->SetIsTesting(true);
+
+  Conv::datum total_sample_count = new_set_->GetSampleCount();
+  Conv::datum running_sample_count = 0;
+
+  for(unsigned int s = 0; s < new_set_->GetSegmentCount(); s++) {
+    Conv::Segment* segment = new_set_->GetSegment(s);
+    Conv::datum segment_score = 0;
+    for(unsigned int sample = 0; sample < segment->GetSampleCount(); sample++) {
+      Conv::JSON& sample_json = segment->GetSample(sample);
+      input_layer_->ForceLoadDetection(sample_json, 0);
+      graph_->FeedForward();
+      segment_score += policy->Score(prediction_buffer.combined_tensor->data, predicted_metadata, 0);
+      running_sample_count += 1.0;
+      progress_handler->OnProjectProgressUpdate(running_sample_count / total_sample_count);
+    }
+    segment->score = segment_score;
+  }
+
+  delete policy;
+  needs_rescore_ = false;
+  progress_handler->OnProjectProgressDone();
+}
+
+bool Project::Save() {
+  bool model_result = SaveModel();
+  if(!model_result)
+    return false;
+
+  Conv::JSON project_json = Serialize();
+  std::string project_filename = project_folder_ + "/project.json";
+  std::ofstream project_file(project_filename, std::ios::out);
+  if(!project_file.good()) {
+    LOGERROR << "Could not open " << project_filename << " for writing";
+    return false;
+  }
+
+  project_file << project_json.dump(2);
+  return true;
+}
+
+bool Project::Load(std::string project_folder) {
+  needs_rescore_ = true;
+  project_folder_ = project_folder;
+  std::string project_filename = project_folder_ + "/project.json";
+  std::ifstream project_file(project_filename, std::ios::in);
+  if(!project_file.good()) {
+    LOGERROR << "Could not open " << project_filename << " for reading";
+    return false;
+  }
+  Conv::JSON project_json = Conv::JSON::parse(project_file);
+  bool project_result = Deserialize(project_json);
+  if(!project_result) {
+    SetState(FAILED);
+    return false;
+  }
+
+  SetState(Project::LOADED);
+
+  // Load Model
+  bool model_result = LoadModel();
+  if(!model_result) {
+    SetState(FAILED);
+    return false;
+  }
+
+  return true;
+}
+
+bool Project::AddSample(std::string filename) {
+  needs_rescore_ = true;
+  Conv::Segment* target_segment = nullptr;
+
+  // If no new samples have been added, create first segment
+  if(new_set_->GetSegmentCount() == 0) {
+    std::stringstream ss; ss << "New Data Batch " << 1;
+    Conv::Segment* segment = new Conv::Segment(ss.str());
+    new_set_->AddSegment(segment);
+    target_segment = segment;
+  } else {
+    // Otherwise, get last segment and check if it has room for new sample
+    Conv::Segment* segment = new_set_->GetSegment(new_set_->GetSegmentCount() - 1);
+    if(segment->GetSampleCount() < new_batch_size_) {
+      target_segment = segment;
+    } else {
+      // No room, create new segment
+      std::stringstream ss; ss << "New Data Batch " << new_set_->GetSegmentCount() + 1;
+      Conv::Segment* segment = new Conv::Segment(ss.str());
+      new_set_->AddSegment(segment);
+      target_segment = segment;
+    }
+  }
+
+  Conv::JSON sample_json = Conv::JSON::object();
+  sample_json["image_filename"] = filename;
+  sample_json["boxes"] = Conv::JSON::array();
+
+  return target_segment->AddSample(sample_json);
+}
+
+void Project::Predict(std::string image_filename, std::vector<Conv::BoundingBox> &predictions) {
+  if(state != Project::LOADED)
+    return;
+
+  std::string found_path = Conv::PathFinder::FindPath(image_filename, "");
+  if(found_path.length() > 0) {
+    Conv::JSON sample_json = Conv::JSON::object();
+    sample_json["image_rpath"] = found_path;
+    sample_json["boxes"] = Conv::JSON::array();
+    input_layer_->ForceWeightsZero();
+    input_layer_->ForceLoadDetection(sample_json, 0);
+
+    graph_->SetIsTesting(true);
+    graph_->FeedForward();
+
+    Conv::DetectionMetadataPointer output_boxes = predicted_metadata_[0];
+    LOGINFO << "Predicted " << output_boxes->size() << " boxes.";
+
+    for(unsigned int b = 0; b < output_boxes->size(); b++) {
+      Conv::BoundingBox bbox = output_boxes->at(b);
+      predictions.push_back(bbox);
+    }
+  } else {
+    LOGERROR << "Could not find " << image_filename << "!";
+  }
+}
+
+Conv::JSON Project::Serialize() {
+  Conv::JSON project_json = Conv::JSON::object();
+  project_json["architecture"] = architecture_json;
+  project_json["update_set"] = update_set_->Serialize();
+  project_json["known_set"] = known_samples_->Serialize();
+  project_json["new_set"] = new_set_->Serialize();
+  project_json["name"] = project_name_;
+  project_json["new_batch_size"] = new_batch_size_;
+  project_json["active_learning_policy"] = active_learning_policy_;
+  return project_json;
+}
+
+bool Project::Deserialize(Conv::JSON& project_json) {
+  if(state != Project::NOTHING) {
+    LOGERROR << "Already have a project!";
+    return false;
+  }
+
+  // Load JSON
+  architecture_filename_="_from_json_";
+  if(!project_json["architecture"].is_object()) {
+    LOGERROR << "Project JSON is missing architecture";
+    return false;
+  }
+  if(!project_json["update_set"].is_object() || !project_json["known_set"].is_object()) {
+    LOGERROR << "Project JSON is missing set informations!";
+    return false;
+  }
+  if(project_json.count("new_batch_size") == 1 && project_json["new_batch_size"].is_number()) {
+    new_batch_size_ = project_json["new_batch_size"];
+  }
+  if(project_json.count("active_learning_policy") == 1 && project_json["active_learning_policy"].is_string()) {
+    active_learning_policy_ = project_json["active_learning_policy"];
+  }
+
+  architecture_json = project_json["architecture"];
+
+  // Create class manager
+  class_manager_ = new Conv::ClassManager();
+  class_manager_->RegisterClassUpdateHandler(class_update_handler_);
+
+  // Load architecture
+  Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
+  graph_ = new Conv::NetGraph();
+
+  // Create dataset input layer
+  unsigned int batch_size_parallel = 1;
+  if(netgraph_factory.GetHyperparameters().count("batch_size_parallel") == 1 && netgraph_factory.GetHyperparameters()["batch_size_parallel"].is_number()) {
+    batch_size_parallel = netgraph_factory.GetHyperparameters()["batch_size_parallel"];
+  }
+  input_layer_ = new Conv::SegmentSetInputLayer(netgraph_factory.GetDataInput(), Conv::DETECTION, class_manager_, batch_size_parallel, 123923);
+  Conv::NetGraphNode* input_node = new Conv::NetGraphNode(input_layer_);
+  input_node->is_input = true;
+
+  // Add other layers
+  graph_->AddNode(input_node);
+  bool result = netgraph_factory.AddLayers(*graph_, class_manager_, 23923);
+  if(!result) {
+    SetState(Project::FAILED);
+    LOGERROR << "Could not construct network!";
+    return false;
+  }
+  graph_->Initialize();
+  graph_->InitializeWeights(true);
+
+  // Set helper pointers
+  predicted_metadata_ = (Conv::DetectionMetadataPointer*) graph_->GetOutputNodes()[0]->output_buffers[0].combined_tensor->metadata;
+
+  // Set project properties
+  project_name_ = project_json["name"];
+
+  // Load trainer
+  trainer_ = new Conv::Trainer(*graph_, netgraph_factory.GetHyperparameters());
+
+  // Load samples
+  known_samples_ = new Conv::SegmentSet("Known Examples");
+  bool deserialization_result = known_samples_->Deserialize(project_json["known_set"]);
+  update_set_ = new Conv::SegmentSet("Update Set");
+  deserialization_result &= update_set_->Deserialize(project_json["update_set"]);
+  new_set_ = new Conv::SegmentSet("New Set");
+  deserialization_result &= new_set_->Deserialize(project_json["new_set"]);
+
+  if(!deserialization_result) {
+    LOGERROR << "SegmentSet deserialization failed! See log for details.";
+    return false;
+  }
+
+  return true;
+}
+
+bool Project::New(std::string architecture_filename, std::string model_filename, std::string project_name, std::string project_folder) {
+  needs_rescore_ = true;
+  if(state != Project::NOTHING) {
+    LOGERROR << "Already have a project!";
+    return false;
+  } else {
+    // Validate filenames
+    std::ifstream architecture_file(architecture_filename, std::ios::in);
+    if(!architecture_file.good()) {
+      LOGERROR << "Failed to open architecture!";
+      return false;
+    }
+
+    std::ifstream model_file(model_filename, std::ios::in | std::ios::binary);
+    if(!model_file.good()) {
+      LOGERROR << "Failed to open model!";
+    }
+
+    // Create class manager
+    class_manager_ = new Conv::ClassManager();
+    class_manager_->RegisterClassUpdateHandler(class_update_handler_);
+
+    // Load architecture
+    architecture_json = Conv::JSON::parse(architecture_file);
+    Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
+    graph_ = new Conv::NetGraph();
+
+    // Create dataset input layer
+    unsigned int batch_size_parallel = 1;
+    if(netgraph_factory.GetHyperparameters().count("batch_size_parallel") == 1 && netgraph_factory.GetHyperparameters()["batch_size_parallel"].is_number()) {
+      batch_size_parallel = netgraph_factory.GetHyperparameters()["batch_size_parallel"];
+    }
+    input_layer_ = new Conv::SegmentSetInputLayer(netgraph_factory.GetDataInput(), Conv::DETECTION, class_manager_, batch_size_parallel, 123923);
+    Conv::NetGraphNode* input_node = new Conv::NetGraphNode(input_layer_);
+    input_node->is_input = true;
+
+    // Add other layers
+    graph_->AddNode(input_node);
+    bool result = netgraph_factory.AddLayers(*graph_, class_manager_, 23923);
+    if(!result) {
+      SetState(Project::FAILED);
+      LOGERROR << "Could not construct network!";
+      return false;
+    }
+    graph_->Initialize();
+    graph_->InitializeWeights(true);
+
+
+    // Load model
+    graph_->DeserializeParameters(model_file);
+
+    // Load trainer
+    trainer_ = new Conv::Trainer(*graph_, netgraph_factory.GetHyperparameters());
+
+    // Initialize segment sets
+    known_samples_ = new Conv::SegmentSet("Known Examples");
+    update_set_ = new Conv::SegmentSet("Update Set");
+    new_set_ = new Conv::SegmentSet("New Set");
+
+    // Set helper pointers
+    predicted_metadata_ = (Conv::DetectionMetadataPointer*) graph_->GetOutputNodes()[0]->output_buffers[0].combined_tensor->metadata;
+
+    // Set project properties
+    project_name_ = project_name;
+    project_folder_ = project_folder;
+
+    SetState(Project::LOADED);
+    return true;
+  }
+}
+
+bool Project::SaveModel() {
+  std::string model_filename = project_folder_ + "/model.CNParamX";
+  std::ofstream model_file(model_filename, std::ios::out | std::ios::binary);
+  if(!model_file.good()) {
+    LOGERROR << "Could not open " << model_filename << " for serializing the model!";
+    return false;
+  }
+  graph_->SerializeParameters(model_file);
+  return true;
+}
+
+bool Project::LoadModel() {
+  needs_rescore_ = true;
+  std::string model_filename = project_folder_ + "/model.CNParamX";
+  std::ifstream model_file(model_filename, std::ios::in | std::ios::binary);
+  if(!model_file.good()) {
+    LOGERROR << "Could not open " << model_filename << " for deserializing the model!";
+    return false;
+  }
+  graph_->DeserializeParameters(model_file);
+  return true;
+}
+
+bool Project::RenameClass(const std::string &org_name, const std::string new_name) {
+  // Rename classes
+  bool class_result = class_manager_->RenameClass(org_name, new_name);
+  if(!class_result) {
+    LOGERROR << "Could not rename class!";
+    return false;
+  }
+
+  // Rename samples
+  bool sample_result = known_samples_->RenameClass(org_name, new_name);
+  sample_result &= new_set_->RenameClass(org_name, new_name);
+  sample_result &= update_set_->RenameClass(org_name, new_name);
+  if(!sample_result) {
+    LOGERROR << "Could not rename class in samples!";
+    SetState(FAILED);
+    return false;
+  }
+
+  return true;
+}

+ 97 - 95
src/Project.h

@@ -5,98 +5,100 @@
  * For licensing information, see the LICENSE file included with this project.
  */
 
-#ifndef PROJECT_H
-#define PROJECT_H
-
-#include <QObject>
-#include <string>
-#include <cn24.h>
-
-class MainWindow;
-class ImportLabeledDataDialog;
-class SelectScoredSegmentDialog;
-class Project : public QObject
-{
-  friend class MainWindow;
-  friend class ImportLabeledDataDialog;
-  friend class SelectScoredSegmentDialog;
-  Q_OBJECT
-public:
-  class ProjectStateHandler {
-  public: virtual void OnProjectStateUpdate() = 0;
-  };
-
-  class ProjectProgressHandler {
-  public: virtual void OnProjectProgressUpdate(float progress) = 0;
-    virtual void OnProjectProgressDone() = 0;
-    virtual void OnProjectProgressFailed() = 0;
-  };
-
-  class ProjectTrainerProgressHandler: public Conv::TrainerProgressUpdateHandler {
-  public:
-    explicit ProjectTrainerProgressHandler(ProjectProgressHandler* handler) : handler(handler) {}
-    ProjectProgressHandler* handler = nullptr;
-    void OnTrainerProgressUpdate(Conv::datum progress) {
-      if(handler != nullptr)
-        handler->OnProjectProgressUpdate((float)progress);
-    }
-  };
-
-  void OnTrainerProgressUpdate(Conv::datum progress);
-
-  Project(QObject *_parent, Conv::ClassManager::ClassUpdateHandler* class_update_handler, ProjectStateHandler* state_handler);
-
-  bool New(std::string architecture_filename, std::string model_filename, std::string project_name, std::string project_folder);
-
-  bool Save();
-  bool Load(std::string project_folder);
-
-  void Predict(std::string image_filename, std::vector<Conv::BoundingBox>& predictions);
-  void PredictSegment(ProjectProgressHandler* progress_handler, Conv::Segment* segment, std::vector<std::string>* prediction_filenames, std::vector<std::vector<Conv::BoundingBox>>* prediction_boxes);
-
-  bool AddSample(std::string filename);
-
-  void UpdateScores(ProjectProgressHandler* progress_handler);
-  void UpdateModel(ProjectProgressHandler* progress_handler);
-
-  enum State {
-    NOTHING,
-    LOADED,
-    FAILED
-  } state;
-
-  void SetNewBatchSize(unsigned int new_batch_size) { needs_rescore_=true; new_batch_size_ = new_batch_size; }
-  void SetActiveLearningPolicy(std::string active_learning_policy) { needs_rescore_=true; active_learning_policy_ = active_learning_policy; }
-private:
-  bool Deserialize(Conv::JSON& project_json);
-  Conv::JSON Serialize();
-  bool SaveModel();
-  bool LoadModel();
-  void SetState(State state) { this->state = state; state_handler_->OnProjectStateUpdate(); }
-
-  unsigned int new_batch_size_ = 32;
-  std::string active_learning_policy_ = "wholeimagediff";
-
-  Conv::JSON architecture_json;
-  std::string architecture_filename_;
-  std::string model_filename_;
-  std::string project_name_;
-  std::string project_folder_;
-
-  Conv::SegmentSetInputLayer* input_layer_;
-  Conv::NetGraph* graph_;
-  Conv::Trainer* trainer_;
-  Conv::DetectionMetadataPointer* predicted_metadata_;
-
-  Conv::ClassManager::ClassUpdateHandler* class_update_handler_;
-  Conv::ClassManager* class_manager_;
-  ProjectStateHandler* state_handler_;
-
-  Conv::SegmentSet* known_samples_;
-  Conv::SegmentSet* update_set_;
-  Conv::SegmentSet* new_set_;
-
-  bool needs_rescore_ = true;
-};
-
-#endif // PROJECT_H
+#ifndef PROJECT_H
+#define PROJECT_H
+
+#include <QObject>
+#include <string>
+#include <cn24.h>
+
+class MainWindow;
+class ImportLabeledDataDialog;
+class SelectScoredSegmentDialog;
+class Project : public QObject
+{
+  friend class MainWindow;
+  friend class ImportLabeledDataDialog;
+  friend class SelectScoredSegmentDialog;
+  Q_OBJECT
+public:
+  class ProjectStateHandler {
+  public: virtual void OnProjectStateUpdate() = 0;
+  };
+
+  class ProjectProgressHandler {
+  public: virtual void OnProjectProgressUpdate(float progress) = 0;
+    virtual void OnProjectProgressDone() = 0;
+    virtual void OnProjectProgressFailed() = 0;
+  };
+
+  class ProjectTrainerProgressHandler: public Conv::TrainerProgressUpdateHandler {
+  public:
+    explicit ProjectTrainerProgressHandler(ProjectProgressHandler* handler) : handler(handler) {}
+    ProjectProgressHandler* handler = nullptr;
+    void OnTrainerProgressUpdate(Conv::datum progress) {
+      if(handler != nullptr)
+        handler->OnProjectProgressUpdate((float)progress);
+    }
+  };
+
+  void OnTrainerProgressUpdate(Conv::datum progress);
+
+  Project(QObject *_parent, Conv::ClassManager::ClassUpdateHandler* class_update_handler, ProjectStateHandler* state_handler);
+
+  bool New(std::string architecture_filename, std::string model_filename, std::string project_name, std::string project_folder);
+
+  bool Save();
+  bool Load(std::string project_folder);
+
+  void Predict(std::string image_filename, std::vector<Conv::BoundingBox>& predictions);
+  void PredictSegment(ProjectProgressHandler* progress_handler, Conv::Segment* segment, std::vector<std::string>* prediction_filenames, std::vector<std::vector<Conv::BoundingBox>>* prediction_boxes);
+
+  bool AddSample(std::string filename);
+
+  void UpdateScores(ProjectProgressHandler* progress_handler);
+  void UpdateModel(ProjectProgressHandler* progress_handler);
+
+  enum State {
+    NOTHING,
+    LOADED,
+    FAILED
+  } state;
+
+  void SetNewBatchSize(unsigned int new_batch_size) { needs_rescore_=true; new_batch_size_ = new_batch_size; }
+  void SetActiveLearningPolicy(std::string active_learning_policy) { needs_rescore_=true; active_learning_policy_ = active_learning_policy; }
+
+  bool RenameClass(const std::string &org_name, const std::string new_name);
+private:
+  bool Deserialize(Conv::JSON& project_json);
+  Conv::JSON Serialize();
+  bool SaveModel();
+  bool LoadModel();
+  void SetState(State state) { this->state = state; state_handler_->OnProjectStateUpdate(); }
+
+  unsigned int new_batch_size_ = 32;
+  std::string active_learning_policy_ = "wholeimagediff";
+
+  Conv::JSON architecture_json;
+  std::string architecture_filename_;
+  std::string model_filename_;
+  std::string project_name_;
+  std::string project_folder_;
+
+  Conv::SegmentSetInputLayer* input_layer_;
+  Conv::NetGraph* graph_;
+  Conv::Trainer* trainer_;
+  Conv::DetectionMetadataPointer* predicted_metadata_;
+
+  Conv::ClassManager::ClassUpdateHandler* class_update_handler_;
+  Conv::ClassManager* class_manager_;
+  ProjectStateHandler* state_handler_;
+
+  Conv::SegmentSet* known_samples_;
+  Conv::SegmentSet* update_set_;
+  Conv::SegmentSet* new_set_;
+
+  bool needs_rescore_ = true;
+};
+
+#endif // PROJECT_H

+ 6 - 3
src/carpediem.pro

@@ -21,7 +21,8 @@ SOURCES += main.cpp\
     SelectSegmentDialog.cpp \
     SelectScoredSegmentDialog.cpp \
     ProgressDialog.cpp \
-    LabelHypothesesDialog.cpp
+    LabelHypothesesDialog.cpp \
+    ClassManagerDialog.cpp
 
 HEADERS  += MainWindow.h \
     ProjectDetailsDialog.h \
@@ -32,7 +33,8 @@ HEADERS  += MainWindow.h \
     SelectSegmentDialog.h \
     SelectScoredSegmentDialog.h \
     ProgressDialog.h \
-    LabelHypothesesDialog.h
+    LabelHypothesesDialog.h \
+    ClassManagerDialog.h
 
 FORMS    += MainWindow.ui \
     ProjectDetailsDialog.ui \
@@ -40,7 +42,8 @@ FORMS    += MainWindow.ui \
     SelectSegmentDialog.ui \
     SelectScoredSegmentDialog.ui \
     ProgressDialog.ui \
-    LabelHypothesesDialog.ui
+    LabelHypothesesDialog.ui \
+    ClassManagerDialog.ui
 
 win32:CONFIG(release, debug|release): LIBS += -lcn24 -lOpenCL -lclBLAS -ljpeg-static
 else:win32:CONFIG(debug, debug|release): LIBS += -lcn24 -lOpenCL -lclBLAS -ljpeg-static