/* * This file is part of the Carpe Diem Active Learning Software, * Copyright (C) 2017 Clemens-Alexander Brust (ikosa dot de at gmail dot com). * * For licensing information, see the LICENSE file included with this project. */ #include #include #include #include #include #include #include #include #include #include "MainWindow.h" #include "ui_MainWindow.h" #include "CDUtils.h" #include "ProjectDetailsDialog.h" #include "ImportLabeledDataDialog.h" #include "SelectSegmentDialog.h" #include "SelectScoredSegmentDialog.h" #include "ProgressDialog.h" #include "LabelHypothesesDialog.h" #include "ClassManagerDialog.h" MainWindow::MainWindow(QWidget *parent) : QMainWindow(parent), ui(new Ui::MainWindow), project(this, this, this) { ui->setupUi(this); OnProjectStateUpdate(); } MainWindow::~MainWindow() { delete ui; } void MainWindow::on_actionNew_triggered() { ProjectDetailsDialog d(this); if(d.exec()) { std::string architecture_filename, model_filename, project_folder, project_name; d.GetOptions(architecture_filename, model_filename, project_folder); bool result = project.New(architecture_filename, model_filename, project_name, project_folder); if(!result) { QMessageBox::critical(this, tr("Error"), tr("Could not create new project! See log for details."), QMessageBox::Ok); } } } void MainWindow::OnClassUpdate() { // Do nothing here } void MainWindow::OnProjectStateUpdate() { // Buttons if(project.state == Project::LOADED) { ui->predictMultipleButton->setEnabled(true); if(project.known_samples_->GetSampleCount() > 0 && project.update_set_->GetSampleCount() > 0) { ui->updateModelButton->setEnabled(true); } else { ui->updateModelButton->setEnabled(false); } ui->importDataButton->setEnabled(true); ui->actionImport_New_Data->setEnabled(true); ui->actionImport_New_Data_Recursively->setEnabled(true); ui->actionImport_Labeled_Data->setEnabled(true); ui->actionReview_Known_Data->setEnabled(true); ui->actionReview_Labeled_Data->setEnabled(true); ui->actionReview_New_Data->setEnabled(true); if(project.new_set_->GetSampleCount() > 0) { ui->labelDataButton->setEnabled(true); } else { ui->labelDataButton->setEnabled(false); } ui->clearPredictionsButton->setEnabled(prediction_files_.size() > 0); ui->exportCSVButton->setEnabled(prediction_files_.size() > 0); ui->actionSave->setEnabled(true); ui->actionSave_As->setEnabled(true); ui->actionNew->setEnabled(false); ui->actionLoad->setEnabled(false); ui->actionSet_Active_Learning_Policy->setEnabled(true); ui->actionManage_Classes->setEnabled(true); ui->actionSave_Snapshot->setEnabled(true); ui->actionLoad_Snapshot->setEnabled(true); } else { ui->predictMultipleButton->setEnabled(false); ui->updateModelButton->setEnabled(false); ui->importDataButton->setEnabled(false); ui->actionImport_New_Data->setEnabled(false); ui->actionImport_New_Data_Recursively->setEnabled(false); ui->actionImport_Labeled_Data->setEnabled(false); ui->actionReview_Known_Data->setEnabled(false); ui->actionReview_Labeled_Data->setEnabled(false); ui->actionReview_New_Data->setEnabled(false); ui->labelDataButton->setEnabled(false); ui->clearPredictionsButton->setEnabled(false); ui->exportCSVButton->setEnabled(false); ui->actionSave->setEnabled(false); ui->actionSave_As->setEnabled(false); ui->actionNew->setEnabled(project.state == Project::NOTHING); ui->actionLoad->setEnabled(project.state == Project::NOTHING); ui->actionSet_Active_Learning_Policy->setEnabled(false); ui->actionManage_Classes->setEnabled(false); ui->actionSave_Snapshot->setEnabled(false); ui->actionLoad_Snapshot->setEnabled(false); } // Model status if(project.state == Project::LOADED) { std::stringstream ss; ss << "Model loaded." << std::endl; // Number of known images if(project.known_samples_->GetSampleCount() > 0) { ss << "Known examples: " << project.known_samples_->GetSampleCount() << std::endl; } else { ss << "You need to import the model's known samples before updating the model." << std::endl; } if(project.new_set_->GetSampleCount() > 0) { ss << "Unlabeled examples: " << project.new_set_->GetSampleCount() << std::endl; } else { ss << "No unlabeled examples, please import unlabeled data to get started." << std::endl; } if(project.update_set_->GetSampleCount() > 0) { ss << "Ready to update: " << project.update_set_->GetSampleCount(); } else { ss << "No labeled samples available, model cannot be updated."; } ui->modelStatusText->setText(QString::fromStdString(ss.str())); } else if(project.state == Project::NOTHING) { ui->modelStatusText->setText("Welcome to Carpe Diem!\nTo get started, please create a new project or load an existing project using the menu options."); } else if(project.state == Project::FAILED) { ui->modelStatusText->setText("Please restart Carpe Diem."); } // Window title if(project.state == Project::LOADED) { std::stringstream ss; ss << "Carpe Diem - " << project.project_name_; setWindowTitle(QString::fromStdString(ss.str())); } else { setWindowTitle("Carpe Diem"); } } void MainWindow::on_actionLoad_YOLO_Small_VOC_triggered() { #ifdef W_OS_LINUX bool result = project.New("/home/clemens/.cn24/yolo/base_all.json", "/home/clemens/.cn24/yolo/yolo-small.CNParamX", "YOLO-Small", "/home/clemens/.cn24/proj"); #else bool result = project.New("C:/Users/clemens/cn24/base_all.json", "C:/Users/clemens/cn24/yolo-small.CNParamX", "YOLO-Small", "C:/Users/clemens/cn24/proj"); #endif if(!result) { QMessageBox::critical(this, tr("Error"), tr("Could not create new project! See log for details."), QMessageBox::Ok); } } void MainWindow::on_predictMultipleButton_clicked() { QStringList fileNames = QFileDialog::getOpenFileNames(this, tr("Open Images to Predict"), QString(), tr("Image (*.jpg *.jpeg *.png)")); ui->progressBar->setValue(0); int old_list_count = ui->predictionList->count(); for(unsigned int i = 0; i < fileNames.size(); i++) { QString fileName = fileNames[i]; std::vector predictions; project.Predict(fileName.toStdString(), predictions); ui->predictionList->addItem(new QListWidgetItem(QIcon(fileName), CDUtils::FileBaseName(fileName))); prediction_boxes_.push_back(predictions); prediction_files_.push_back(fileName); int progress = (100 * (i + 1)) / fileNames.size(); ui->progressBar->setValue(progress); } if(fileNames.size() > 0) ui->predictionList->setCurrentRow(old_list_count); OnProjectStateUpdate(); ui->progressBar->setValue(100); ui->predictionList->focusWidget(); } void MainWindow::on_predictionList_currentItemChanged(QListWidgetItem *current, QListWidgetItem *previous) { if(project.state != Project::LOADED) return; if(current != nullptr) { int index = ui->predictionList->row(current); if(index >= 0 && index < prediction_boxes_.size()) { ui->imageDisplay->DisplayImage(prediction_files_[index].toStdString(), prediction_boxes_[index], project.class_manager_); } } } void MainWindow::on_clearPredictionsButton_clicked() { ui->predictionList->clear(); prediction_boxes_.clear(); prediction_files_.clear(); OnProjectStateUpdate(); ui->imageDisplay->ClearImage(); } void MainWindow::on_importLabeledDataButton_clicked() { auto fileName = QFileDialog::getOpenFileName(this, tr("Open Labeled Dataset"), QString(), tr("Dataset (*.json)")); if(fileName.length() > 0) { ImportLabeledDataDialog d(this, fileName.toStdString(), &project); d.exec(); OnProjectStateUpdate(); } } void MainWindow::on_actionSave_triggered() { bool result = project.Save(); if(!result) { QMessageBox::critical(this, tr("Error"), tr("Could not save project! See log for details."), QMessageBox::Ok); } } void MainWindow::on_actionLoad_triggered() { auto folderName = QFileDialog::getExistingDirectory(this, tr("Select Project Folder")); if(folderName.length() > 0) { bool result = project.Load(folderName.toStdString()); if(!result) { QMessageBox::critical(this, tr("Error"), tr("Could not load project! See log for details."), QMessageBox::Ok); } } } void MainWindow::on_importDataButton_clicked() { QStringList fileNames = QFileDialog::getOpenFileNames(this, tr("Open Images to Import As Unlabeled Data"), QString(), tr("Image (*.jpg *.jpeg *.png)")); std::mt19937 generator; ui->progressBar->setValue(0); int maxprogress = fileNames.size(); int i = 0; while(fileNames.size() > 0) { std::uniform_int_distribution dist(0, fileNames.size() - 1); int file_index = dist(generator); QString fileName = fileNames[file_index]; fileNames.removeAt(file_index); bool result = project.AddSample(fileName.toStdString()); if(!result) { QMessageBox::critical(this, tr("Error"), tr("Could not import unlabeled sample! See log for details."), QMessageBox::Ok); OnProjectStateUpdate(); break; } int progress = (100 * (i++ + 1)) / maxprogress; ui->progressBar->setValue(progress); } OnProjectStateUpdate(); ui->progressBar->setValue(100); } void MainWindow::on_actionImport_New_Data_Recursively_triggered() { auto folderName = QFileDialog::getExistingDirectory(this, tr("Select New Data Folder")); if(folderName.length() <= 0) return; QDirIterator iterator(folderName, QStringList() << "*.jpg" << "*.png", QDir::Files, QDirIterator::Subdirectories); QStringList fileNames; while(iterator.hasNext()) fileNames << iterator.next(); std::mt19937 generator; ui->progressBar->setValue(0); int maxprogress = fileNames.size(); int i = 0; while(fileNames.size() > 0) { std::uniform_int_distribution dist(0, fileNames.size() - 1); int file_index = dist(generator); QString fileName = fileNames[file_index]; fileNames.removeAt(file_index); bool result = project.AddSample(fileName.toStdString()); if(!result) { QMessageBox::critical(this, tr("Error"), tr("Could not import unlabeled sample! See log for details."), QMessageBox::Ok); OnProjectStateUpdate(); break; } int progress = (100 * (i++ + 1)) / maxprogress; ui->progressBar->setValue(progress); } OnProjectStateUpdate(); ui->progressBar->setValue(100); } void MainWindow::on_actionImport_New_Data_triggered() { on_importDataButton_clicked(); } void MainWindow::on_actionImport_Labeled_Data_triggered() { on_importLabeledDataButton_clicked(); } void MainWindow::AddSegmentImagesToList(Conv::Segment *segment) { if(segment != nullptr) { unsigned int sample_count = segment->GetSampleCount(); if(sample_count > 100) { QMessageBox::StandardButton result = QMessageBox::information(this, tr("Information"), tr("You have chosen more than 100 images for review. Do you want to review all images? Select No to review only the first 100."), QMessageBox::Yes, QMessageBox::No); if(result == QMessageBox::No) sample_count = 100; } for(unsigned int sample = 0; sample < sample_count; sample++) { Conv::JSON& sample_json = segment->GetSample(sample); std::vector predictions; Conv::Segment::CopyDetectionMetadata(sample_json, 1, 1, *(project.class_manager_), &predictions); for(unsigned int b = 0; b < predictions.size(); b++) { // Notify the view that these coordinates are not normalized predictions[b].flag1 = true; } QString fileName = QString::fromStdString(sample_json["image_rpath"]); ui->predictionList->addItem(new QListWidgetItem(QIcon(fileName), CDUtils::FileBaseName(fileName))); prediction_boxes_.push_back(predictions); prediction_files_.push_back(fileName); } } OnProjectStateUpdate(); } void MainWindow::on_actionReview_Known_Data_triggered() { SelectSegmentDialog d(this, project.known_samples_, &project); if(d.exec()) { Conv::Segment* segment = d.GetSegment(); AddSegmentImagesToList(segment); } OnProjectStateUpdate(); } void MainWindow::on_actionReview_Labeled_Data_triggered() { SelectSegmentDialog d(this, project.update_set_, &project); if(d.exec()) { Conv::Segment* segment = d.GetSegment(); AddSegmentImagesToList(segment); } OnProjectStateUpdate(); } void MainWindow::on_actionReview_New_Data_triggered() { SelectScoredSegmentDialog d(this, &project); if(d.exec()) { Conv::Segment* segment = d.GetSegment(); AddSegmentImagesToList(segment); } OnProjectStateUpdate(); } void MainWindow::on_updateModelButton_clicked() { QString description = tr("Updating model..."); ProgressDialog d(this, &project, description); QFuture future = QtConcurrent::run(&project, &Project::UpdateModel, (Project::ProjectProgressHandler*)&d); if(!d.exec(&future)) { QMessageBox::critical(this, tr("Error"), tr("Could not update model! See log for details."), QMessageBox::Ok); } else { // Move update segments to known segments while(project.update_set_->GetSegmentCount() > 0) { Conv::Segment* segment = project.update_set_->GetSegment(0); project.known_samples_->AddSegment(segment); project.update_set_->RemoveSegment(0); } project.SaveSnapshot(); } OnProjectStateUpdate(); } void MainWindow::on_labelDataButton_clicked() { QString description = tr("Selecting best batch from new data..."); ProgressDialog d(this, &project, description); QFuture future = QtConcurrent::run(&project, &Project::UpdateScores, (Project::ProjectProgressHandler*)&d); if(!d.exec(&future)) { QMessageBox::critical(this, tr("Error"), tr("Could not select best batch! See log for details."), QMessageBox::Ok); } else { // Segments are scored, select best segment Conv::Segment* best_segment =project.new_set_->GetSegment(0); Conv::datum best_score = 0; int best_segment_index = 0; for(unsigned int s = 0; s < project.new_set_->GetSegmentCount(); s++) { if(project.new_set_->GetSegment(s)->score > best_score) { best_score = project.new_set_->GetSegment(s)->score; best_segment = project.new_set_->GetSegment(s); best_segment_index = s; } } // Predict segment QString description2 = tr("Generating predictions on new data..."); ProgressDialog d2(this, &project, description2); std::vector> boxes; std::vector filenames; // Start prediction future QFuture future2 = QtConcurrent::run(&project, &Project::PredictSegment, (Project::ProjectProgressHandler*)&d2, best_segment, &filenames, &boxes); if(!d2.exec(&future2)) { QMessageBox::critical(this, tr("Error"), tr("Could not predict new samples! See log for details."), QMessageBox::Ok); } else { // Display labeling dialog LabelHypothesesDialog d3(this, boxes, filenames, project.class_manager_); if(d3.exec()) { if(d3.StoreInSegment(best_segment)) { project.new_set_->RemoveSegment(best_segment_index); project.update_set_->AddSegment(best_segment); } else { // JSON error QMessageBox::critical(this, tr("Error"), tr("Could not save labels! See log for details."), QMessageBox::Ok); } } else { // User interrupted labeling, stop QMessageBox::critical(this, tr("Error"), tr("Labeling was interrupted."), QMessageBox::Ok); } } } OnProjectStateUpdate(); } void MainWindow::on_actionSet_Active_Learning_Policy_triggered() { QStringList list({"wholeimagediff", "wholeimage1vs2", "proposalsum1vs2", "proposalmax1vs2","proposalavg1vs2"}); QString result = QInputDialog::getItem(this, tr("Select Active Learning Policy"), tr("Active Learning Policy:"), list, 0, false); if(result.length() > 0) { project.SetActiveLearningPolicy(result.toStdString()); OnProjectStateUpdate(); } } void MainWindow::on_actionSet_New_Sample_Batch_Size_triggered() { int result = QInputDialog::getInt(this, tr("Select New Sample Batch Size"), tr("New Sample Batch Size:"), 32, 1, 320); if(result > 0) { project.SetNewBatchSize(result); OnProjectStateUpdate(); } } void MainWindow::on_exportCSVButton_clicked() { QString targetFile = QFileDialog::getSaveFileName(this, tr("Select a CSV File"), QString(), tr("CSV File (*.csv)")); if(targetFile.length() > 0) { std::string filename = targetFile.toStdString(); std::ofstream csv_output(filename, std::ios::out); csv_output << "FileName;Class;CenterX;CenterY;Width;Height" << std::endl; ui->progressBar->setValue(0); for(unsigned int s = 0; s < prediction_boxes_.size(); s++) { std::vector& sample_boxes = prediction_boxes_[s]; QImage image(prediction_files_[s]); Conv::datum width = image.width(), height = image.height(); std::string filename = prediction_files_[s].toStdString(); // One CSV line per bounding box for(unsigned int b = 0; b < sample_boxes.size(); b++) { Conv::BoundingBox& box = sample_boxes[b]; csv_output << filename << ";"; csv_output << project.class_manager_->GetClassInfoById(box.c).first << ";"; if(box.flag1) { // Coordinates are not normalized, direct output csv_output << box.x << ";" << box.y << ";" << box.w << ";" << box.h; } else { // Coordinates are normalized, find out image proportions csv_output << box.x * width << ";" << box.y * height << ";" << box.w * width << ";" << box.h * height; } csv_output << std::endl; } ui->progressBar->setValue((100 * (s+1)) / prediction_boxes_.size()); } ui->progressBar->setValue(0); } } void MainWindow::on_actionAbout_Carpe_Diem_triggered() { QMessageBox::information(this, tr("About Carpe Diem"), tr("Carpe Diem Version 1.0.0-alpha1\n\nCopyright (C) 2017 Clemens-Alexander Brust\nE-Mail:ikosa.de@gmail.com"), QMessageBox::Ok); } void MainWindow::on_actionManage_Classes_triggered() { ClassManagerDialog d(this, project.class_manager_, &project); d.exec(); } void MainWindow::on_actionLoad_Snapshot_triggered() { QString fileName = QFileDialog::getOpenFileName(this, tr("Load Snapshot"), QString::fromStdString(project.project_folder_), tr("CN24 Model (*.CNParamX)")); if(fileName.length() > 0) { std::string file_name = fileName.toStdString(); if(!project.LoadSnapshot(file_name)) { QMessageBox::warning(this, tr("Error"), tr("Could not load snapshot! See log for details.")); } OnProjectStateUpdate(); } } void MainWindow::on_actionSave_Snapshot_triggered() { project.SaveSnapshot(); }