Prechádzať zdrojové kódy

[feature] Load and save snapshots

Clemens-Alexander Brust 9 rokov pred
rodič
commit
11177ea4df
5 zmenil súbory, kde vykonal 75 pridanie a 0 odobranie
  1. 21 0
      src/MainWindow.cpp
  2. 4 0
      src/MainWindow.h
  3. 13 0
      src/MainWindow.ui
  4. 34 0
      src/Project.cpp
  5. 3 0
      src/Project.h

+ 21 - 0
src/MainWindow.cpp

@@ -86,6 +86,8 @@ void MainWindow::OnProjectStateUpdate() {
     ui->actionLoad->setEnabled(false);
     ui->actionSet_Active_Learning_Policy->setEnabled(true);
     ui->actionManage_Classes->setEnabled(true);
+    ui->actionSave_Snapshot->setEnabled(true);
+    ui->actionLoad_Snapshot->setEnabled(true);
   } else {
     ui->predictMultipleButton->setEnabled(false);
     ui->updateModelButton->setEnabled(false);
@@ -104,6 +106,8 @@ void MainWindow::OnProjectStateUpdate() {
     ui->actionLoad->setEnabled(project.state == Project::NOTHING);
     ui->actionSet_Active_Learning_Policy->setEnabled(false);
     ui->actionManage_Classes->setEnabled(false);
+    ui->actionSave_Snapshot->setEnabled(false);
+    ui->actionLoad_Snapshot->setEnabled(false);
   }
 
   // Model status
@@ -469,3 +473,20 @@ void MainWindow::on_actionManage_Classes_triggered()
   ClassManagerDialog d(this, project.class_manager_, &project);
   d.exec();
 }
+
+void MainWindow::on_actionLoad_Snapshot_triggered()
+{
+  QString fileName = QFileDialog::getOpenFileName(this, tr("Load Snapshot"), QString::fromStdString(project.project_folder_), tr("CN24 Model (*.CNParamX)"));
+  if(fileName.length() > 0) {
+    std::string file_name = fileName.toStdString();
+    if(!project.LoadSnapshot(file_name)) {
+      QMessageBox::warning(this, tr("Error"), tr("Could not load snapshot! See log for details."));
+    }
+    OnProjectStateUpdate();
+  }
+}
+
+void MainWindow::on_actionSave_Snapshot_triggered()
+{
+  project.SaveSnapshot();
+}

+ 4 - 0
src/MainWindow.h

@@ -72,6 +72,10 @@ private slots:
 
   void on_actionManage_Classes_triggered();
 
+  void on_actionLoad_Snapshot_triggered();
+
+  void on_actionSave_Snapshot_triggered();
+
 private:
   void on_importLabeledDataButton_clicked(); // not a slot anymore
   void AddSegmentImagesToList(Conv::Segment* segment);

+ 13 - 0
src/MainWindow.ui

@@ -124,6 +124,9 @@
     </property>
     <addaction name="actionSet_Active_Learning_Policy"/>
     <addaction name="actionSet_New_Sample_Batch_Size"/>
+    <addaction name="separator"/>
+    <addaction name="actionLoad_Snapshot"/>
+    <addaction name="actionSave_Snapshot"/>
    </widget>
    <widget class="QMenu" name="menuHelp">
     <property name="title">
@@ -291,6 +294,16 @@
     <string>Manage Classes...</string>
    </property>
   </action>
+  <action name="actionLoad_Snapshot">
+   <property name="text">
+    <string>Load Snapshot...</string>
+   </property>
+  </action>
+  <action name="actionSave_Snapshot">
+   <property name="text">
+    <string>Save Snapshot</string>
+   </property>
+  </action>
   <zorder>dockWidget_2</zorder>
   <zorder>dockWidget_4</zorder>
  </widget>

+ 34 - 0
src/Project.cpp

@@ -8,6 +8,10 @@
 #include "Project.h"
 
 #include <string>
+#include <sstream>
+#include <chrono>
+#include <iomanip>
+#include <ctime>
 #include <cn24.h>
 
 Project::Project(QObject* _parent, Conv::ClassManager::ClassUpdateHandler* class_update_handler, ProjectStateHandler* project_state_handler)
@@ -393,6 +397,36 @@ bool Project::New(std::string architecture_filename, std::string model_filename,
   }
 }
 
+void Project::SaveSnapshot() {
+  if(state != LOADED)
+    return;
+
+  std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+  std::time_t now_c = std::chrono::system_clock::to_time_t(now);
+  std::stringstream ss;
+  ss << project_folder_ << "/" << "snapshot-" << std::put_time(std::localtime(&now_c), "%y%m%d-%H%M") << "-" << known_samples_->GetSampleCount() << ".CNParamX";
+  std::string model_filename = ss.str();
+  std::ofstream model_file(model_filename, std::ios::out | std::ios::binary);
+  if(!model_file.good()) {
+    LOGERROR << "Could not open " << model_filename << " for serializing the model!";
+  }
+  graph_->SerializeParameters(model_file);
+}
+
+bool Project::LoadSnapshot(std::string &file_name) {
+  if(state != LOADED)
+    return false;
+
+  needs_rescore_ = true;
+  std::ifstream model_file(file_name, std::ios::in | std::ios::binary);
+  if(!model_file.good()) {
+    LOGERROR << "Could not open " << file_name << " for deserializing the model!";
+    return false;
+  }
+  graph_->DeserializeParameters(model_file);
+  return true;
+}
+
 bool Project::SaveModel() {
   std::string model_filename = project_folder_ + "/model.CNParamX";
   std::ofstream model_file(model_filename, std::ios::out | std::ios::binary);

+ 3 - 0
src/Project.h

@@ -69,6 +69,9 @@ public:
   void SetActiveLearningPolicy(std::string active_learning_policy) { needs_rescore_=true; active_learning_policy_ = active_learning_policy; }
 
   bool RenameClass(const std::string &org_name, const std::string new_name);
+
+  void SaveSnapshot();
+  bool LoadSnapshot(std::string& file_name);
 private:
   bool Deserialize(Conv::JSON& project_json);
   Conv::JSON Serialize();