Clemens-Alexander Brust 7 lat temu
rodzic
commit
e2f97cbc95
6 zmienionych plików z 121 dodań i 16 usunięć
  1. 58 9
      src/AnnotatedImageView.cpp
  2. 9 0
      src/AnnotatedImageView.h
  3. 6 2
      src/MainWindow.cpp
  4. 44 3
      src/Project.cpp
  5. 2 1
      src/Project.h
  6. 2 1
      src/carpediem.pro

+ 58 - 9
src/AnnotatedImageView.cpp

@@ -10,12 +10,20 @@
 #include <QGraphicsSimpleTextItem>
 #include <QPixmap>
 #include <QImage>
+#include <QMenu>
+#include <QFileDialog>
 
 AnnotatedImageView::AnnotatedImageView(QWidget* parent) : QGraphicsView(parent)
 {
   scene_ = new QGraphicsScene(this);
   setScene(scene_);
   rect_ = QRectF(0, 0, 0, 0);
+
+  actionExportImage = new QAction(tr("Export Annotated Image..."), this);
+  actionExportImage->setEnabled(false);
+  connect(actionExportImage, &QAction::triggered, this, &AnnotatedImageView::exportImage);
+
+  setRenderHint(QPainter::Antialiasing);
 }
 
 void AnnotatedImageView::ClearImage() {
@@ -25,6 +33,7 @@ void AnnotatedImageView::ClearImage() {
     image_ = nullptr;
   }
   rect_ = QRectF(0, 0, 0, 0);
+  actionExportImage->setEnabled(false);
 }
 
 void AnnotatedImageView::DisplayImage(std::string filename, std::vector<Conv::BoundingBox> bounding_boxes, Conv::ClassManager *class_manager) {
@@ -62,30 +71,46 @@ void AnnotatedImageView::DisplayImage(std::string filename, std::vector<Conv::Bo
 
     // Klassenbezeichnung ausgeben
     QString class_name = QString::fromStdString(class_manager->GetClassInfoById(box.c).first);
+
+
+
     auto* text_item = new QGraphicsSimpleTextItem(class_name);
     if(image_->width() >= 800) {
       text_item->setFont(QFont(tr("Helvetica"), 24));
     } else {
-      text_item->setFont(QFont(tr("Helvetica"), 16));
+      text_item->setFont(QFont(tr("Helvetica"), 32));
     }
-    text_item->setPos(box.x, box.y + box.h);
+    text_item->setPos(box.x + 10, box.y - 3);// + box.h);
     qreal text_width = text_item->boundingRect().width();
     qreal text_height = text_item->boundingRect().height();
-    if(box.y + box.h + text_height > y2) y2 = box.y + box.h + text_height;
+    text_item->setPos(box.x + 10, box.y - text_height - 3);// + box.h);
+    // if(box.y + box.h + text_height > y2) y2 = box.y + box.h + text_height;
+
+    QColor class_background = QColor::fromRgbF(1,1,1,0.9);
+    QColor box_background = QColor::fromRgbF(1,1,1,0.2);
+    QColor box_border = QColor::fromRgbF(0,0,0,1);
+
+    if(box.c == UNKNOWN_CLASS) {
+      class_background = QColor::fromRgbF(1,0.5,0.5,0.5);
+      box_background = QColor::fromRgbF(1,0.5,0.5,0.1);
+      box_border = QColor::fromRgbF(0.5,0,0,0.5);
+    }
 
     // Hintergrund der Klassenbezeichnung
-    scene_->addRect(box.x, box.y + box.h, text_width, text_height, QPen(), QBrush(QColor::fromRgbF(1,1,1,0.9)));
+    scene_->addRect(box.x, box.y - text_height, text_width + 20, text_height, QPen(), QBrush(QColor::fromRgbF(1,1,1,0.9)));
     scene_->addItem(text_item);
 
     // Bounding Box zeichnen
-    scene_->addRect(box.x, box.y, box.w, box.h,QPen(QBrush(QColor::fromRgbF(0,0,0,1)), 3.0), QBrush(QColor::fromRgbF(1,1,1,0.2)));
-
+    scene_->addRect(box.x, box.y, box.w, box.h,QPen(QBrush(box_border), 3.0), QBrush(box_background));
   }
 
-  if(image_->width() >= 800) {
+  const qreal border = 85;
+  QRectF bounding_rect(x1 - border, y1 - border, (2.0 * border) + x2 - x1, (2.0 * border) + y2 - y1);
+  QRectF noborder_rect(x1, y1, x2 - x1, y2 - y1);
+  scene_->setSceneRect(noborder_rect);
+
+  if(width >= 800) {
     // Großes Bild, Umgebungsrechteck für passende Anzeige feststellen
-    const qreal border = 85;
-    QRectF bounding_rect(x1 - border, y1 - border, (2.0 * border) + x2 - x1, (2.0 * border) + y2 - y1);
     rect_ = bounding_rect;
     RefitDisplay();
   } else {
@@ -93,6 +118,7 @@ void AnnotatedImageView::DisplayImage(std::string filename, std::vector<Conv::Bo
     rect_ = QRectF(0, 0, 0, 0);
     RefitDisplay();
   }
+  actionExportImage->setEnabled(true);
 }
 
 void AnnotatedImageView::RefitDisplay() {
@@ -102,3 +128,26 @@ void AnnotatedImageView::RefitDisplay() {
     this->fitInView(rect_, Qt::KeepAspectRatio);
   }
 }
+
+void AnnotatedImageView::contextMenuEvent(QContextMenuEvent *event) {
+  QMenu menu(this);
+  menu.addAction(actionExportImage);
+  menu.exec(event->globalPos());
+}
+
+void AnnotatedImageView::exportImage() {
+  // Calculate image bounds
+  auto rect = scene_->sceneRect();
+  QImage image(rect.width(), rect.height(), QImage::Format_RGBA8888);
+
+  // Render image
+  QPainter painter(&image);
+  painter.setRenderHint(QPainter::Antialiasing);
+  scene_->render(&painter);
+
+  // Get file name
+  QString fileName = QFileDialog::getSaveFileName(this, tr("Export Annotated Image"), QString(), tr("Image (*.jpg *.jpeg *.png)"));
+
+  if(fileName.length() > 0)
+    image.save(fileName);
+}

+ 9 - 0
src/AnnotatedImageView.h

@@ -14,6 +14,8 @@
 #include <QImage>
 #include <QRectF>
 #include <QResizeEvent>
+#include <QAction>
+#include <QContextMenuEvent>
 
 #include <cn24.h>
 
@@ -47,8 +49,15 @@ public:
    */
   void RefitDisplay();
 
+protected:
+  void contextMenuEvent(QContextMenuEvent *event);
+
+private slots:
+  void exportImage();
+
 private:
   void resizeEvent(QResizeEvent* evt) { RefitDisplay(); QGraphicsView::resizeEvent(evt); }
+  QAction* actionExportImage = nullptr;
   QGraphicsScene* scene_ = nullptr;
   QGraphicsPixmapItem* item_ = nullptr;
   QImage* image_ = nullptr;

+ 6 - 2
src/MainWindow.cpp

@@ -453,10 +453,14 @@ void MainWindow::on_labelDataButton_clicked()
 
 void MainWindow::on_actionSet_Active_Learning_Policy_triggered()
 {
-  QStringList list({"wholeimagediff", "wholeimage1vs2", "proposalsum1vs2", "proposalmax1vs2","proposalavg1vs2"});
+  QStringList list({"wholeimagediff", "wholeimage1vs2", "proposalsum1vs2", "wproposalsum1vs2", "proposalmax1vs2","proposalavg1vs2"});
   QString result = QInputDialog::getItem(this, tr("Select Active Learning Policy"), tr("Active Learning Policy:"), list, 0, false);
   if(result.length() > 0) {
-    project.SetActiveLearningPolicy(result.toStdString());
+    if(result.compare(tr("wproposalsum1vs2"))==0) {
+      project.SetActiveLearningPolicy("proposalsum1vs2", true);
+    } else {
+      project.SetActiveLearningPolicy(result.toStdString(), false);
+    }
     OnProjectStateUpdate();
   }
 }

+ 44 - 3
src/Project.cpp

@@ -100,7 +100,7 @@ void Project::UpdateScores(ProjectProgressHandler *progress_handler) {
   }
 
   Conv::JSONNetGraphFactory netgraph_factory(architecture_json, 123123);
-  Conv::ActiveLearningPolicy* policy = Conv::YOLOActiveLearningPolicy::CreateWithName(active_learning_policy_,netgraph_factory.GetYOLOConfiguration());
+  Conv::ActiveLearningPolicy* policy = Conv::YOLOActiveLearningPolicy::CreateWithName(active_learning_policy_,netgraph_factory.GetYOLOConfiguration(), 0);
 
   Conv::NetGraphBuffer& prediction_buffer = graph_->GetOutputNodes()[0]->output_buffers[0];
   Conv::DatasetMetadataPointer* predicted_metadata = prediction_buffer.combined_tensor->metadata;
@@ -108,6 +108,33 @@ void Project::UpdateScores(ProjectProgressHandler *progress_handler) {
   input_layer_->ForceWeightsZero();
   graph_->SetIsTesting(true);
 
+  // Klassengewichte ausrechnen
+  const unsigned int class_vector_len = class_manager_->GetMaxClassId() + 1;
+  std::vector<unsigned int> class_counts = std::vector<unsigned int>(class_vector_len);
+  for(unsigned int c = 0; c < class_vector_len; c++)
+      class_counts[c] = 0;
+  unsigned int total_labeled_sample_count = 0;
+  unsigned int total_labeled_class_count = class_manager_->GetClassCount();
+
+  if(active_learning_weights_) {
+
+    const unsigned int training_set_sample_count = known_samples_->GetSampleCount();
+    for(unsigned int s = 0; s < training_set_sample_count; s++) {
+      Conv::JSON sample = known_samples_->GetSample(s);
+      Conv::DetectionMetadata metadata;
+      Conv::Segment::CopyDetectionMetadata(sample, 0, 0, *class_manager_, &metadata);
+      for(Conv::BoundingBox& box : metadata) {
+        const unsigned int class_index = box.c;
+        if(class_index <= class_manager_->GetMaxClassId()) {
+          class_counts[class_index] += 1;
+          total_labeled_sample_count += 1;
+        }
+      }
+    }
+
+
+  }
+
   Conv::datum total_sample_count = new_set_->GetSampleCount();
   Conv::datum running_sample_count = 0;
 
@@ -124,7 +151,16 @@ void Project::UpdateScores(ProjectProgressHandler *progress_handler) {
       }
       graph_->FeedForward();
       for(unsigned int bindex = 0; bindex < batch_size && (sample+bindex) < (segment->GetSampleCount()); bindex++) {
-        segment_score += policy->Score(prediction_buffer.combined_tensor->data, predicted_metadata, bindex);
+        if(active_learning_weights_) {
+          std::vector<Conv::datum> class_weights = std::vector<Conv::datum>(class_vector_len);
+          for(unsigned int c = 0; c < class_vector_len; c++) {
+            class_weights[c] = (total_labeled_sample_count + (Conv::datum)total_labeled_class_count) /
+                (class_counts[c] + (Conv::datum)1.0);
+          }
+          segment_score += policy->ScoreW(prediction_buffer.combined_tensor->data, predicted_metadata, bindex, class_weights);
+        } else {
+          segment_score += policy->Score(prediction_buffer.combined_tensor->data, predicted_metadata, bindex);
+        }
         running_sample_count += 1.0;
         progress_handler->OnProjectProgressUpdate(running_sample_count / total_sample_count);
       }
@@ -249,6 +285,7 @@ Conv::JSON Project::Serialize() {
   project_json["name"] = project_name_;
   project_json["new_batch_size"] = new_batch_size_;
   project_json["active_learning_policy"] = active_learning_policy_;
+  project_json["active_learning_weights"] = active_learning_weights_;
   return project_json;
 }
 
@@ -274,6 +311,9 @@ bool Project::Deserialize(Conv::JSON& project_json) {
   if(project_json.count("active_learning_policy") == 1 && project_json["active_learning_policy"].is_string()) {
     active_learning_policy_ = project_json["active_learning_policy"];
   }
+  if(project_json.count("active_learning_weights") == 1 && project_json["active_learning_weights"].is_number()) {
+    active_learning_weights_ = project_json["active_learning_weights"];
+  }
 
   architecture_json = project_json["architecture"];
 
@@ -408,7 +448,8 @@ void Project::SaveSnapshot() {
   std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
   std::time_t now_c = std::chrono::system_clock::to_time_t(now);
   std::stringstream ss;
-  ss << project_folder_ << "/" << "snapshot-" << std::put_time(std::localtime(&now_c), "%y%m%d-%H%M") << "-" << known_samples_->GetSampleCount() << ".CNParamX";
+  //ss << project_folder_ << "/" << "snapshot-" << std::put_time(std::localtime(&now_c), "%y%m%d-%H%M") << "-" << known_samples_->GetSampleCount() << ".CNParamX";
+  ss << project_folder_ << "/" << "snapshot-" << "-" << known_samples_->GetSampleCount() << ".CNParamX";
   std::string model_filename = ss.str();
   std::ofstream model_file(model_filename, std::ios::out | std::ios::binary);
   if(!model_file.good()) {

+ 2 - 1
src/Project.h

@@ -162,7 +162,7 @@ public:
    * \brief Ändert das Bewertungskriterium für Lernstapel
    * \param active_learning_policy Neues Bewertungskriterium
    */
-  void SetActiveLearningPolicy(std::string active_learning_policy) { needs_rescore_=true; active_learning_policy_ = active_learning_policy; }
+  void SetActiveLearningPolicy(std::string active_learning_policy, bool active_learning_weights) { needs_rescore_=true; active_learning_policy_ = active_learning_policy; active_learning_weights_ = active_learning_weights; }
 
   /*!
    * \brief Benennt bestehende Klasse um
@@ -192,6 +192,7 @@ private:
 
   unsigned int new_batch_size_ = 32;
   std::string active_learning_policy_ = "wholeimagediff";
+  bool active_learning_weights_ = false;
 
   Conv::JSON architecture_json;
   std::string architecture_filename_;

+ 2 - 1
src/carpediem.pro

@@ -47,6 +47,7 @@ FORMS    += MainWindow.ui \
 
 win32:CONFIG(release, debug|release): LIBS += -lcn24 -lOpenCL -lclBLAS -ljpeg-static
 else:win32:CONFIG(debug, debug|release): LIBS += -lcn24 -lOpenCL -lclBLAS -ljpeg-static
-else:unix: LIBS += -lcn24
+else:unix: LIBS += /home/brust/local/cn24-active/libcn24.so
+INCLUDEPATH += /home/brust/repos/cn24-active/include
 
 CONFIG += c++11