Procházet zdrojové kódy

[fix] Use full batch size for segment scoring

Clemens-Alexander Brust před 9 roky
rodič
revize
5b993140a0
2 změnil soubory, kde provedl 18 přidání a 6 odebrání
  1. 13 6
      src/Project.cpp
  2. 5 0
      src/SelectScoredSegmentDialog.cpp

+ 13 - 6
src/Project.cpp

@@ -103,16 +103,23 @@ void Project::UpdateScores(ProjectProgressHandler *progress_handler) {
   Conv::datum total_sample_count = new_set_->GetSampleCount();
   Conv::datum running_sample_count = 0;
 
+  unsigned int batch_size = prediction_buffer.combined_tensor->data.samples();
+
   for(unsigned int s = 0; s < new_set_->GetSegmentCount(); s++) {
     Conv::Segment* segment = new_set_->GetSegment(s);
     Conv::datum segment_score = 0;
-    for(unsigned int sample = 0; sample < segment->GetSampleCount(); sample++) {
-      Conv::JSON& sample_json = segment->GetSample(sample);
-      input_layer_->ForceLoadDetection(sample_json, 0);
+
+    for(unsigned int sample = 0; sample < segment->GetSampleCount(); sample+= batch_size) {
+      for(unsigned int bindex = 0; bindex < batch_size && (sample+bindex) < (segment->GetSampleCount()); bindex++) {
+        Conv::JSON& sample_json = segment->GetSample(sample+bindex);
+        input_layer_->ForceLoadDetection(sample_json, bindex);
+      }
       graph_->FeedForward();
-      segment_score += policy->Score(prediction_buffer.combined_tensor->data, predicted_metadata, 0);
-      running_sample_count += 1.0;
-      progress_handler->OnProjectProgressUpdate(running_sample_count / total_sample_count);
+      for(unsigned int bindex = 0; bindex < batch_size && (sample+bindex) < (segment->GetSampleCount()); bindex++) {
+        segment_score += policy->Score(prediction_buffer.combined_tensor->data, predicted_metadata, bindex);
+        running_sample_count += 1.0;
+        progress_handler->OnProjectProgressUpdate(running_sample_count / total_sample_count);
+      }
     }
     segment->score = segment_score;
   }

+ 5 - 0
src/SelectScoredSegmentDialog.cpp

@@ -78,12 +78,14 @@ void SelectScoredSegmentDialog::OnProjectProgressUpdate(float progress) {
 void SelectScoredSegmentDialog::InnerOnProjectProgressDone() {
   ui->progressBar->setValue(0);
   ui->progressBar->setEnabled(false);
+  ui->updateScoreButton->setEnabled(true);
   UpdateSegments();
 }
 
 void SelectScoredSegmentDialog::InnerOnProjectProgressFailed() {
   ui->progressBar->setValue(0);
   ui->progressBar->setEnabled(false);
+  ui->updateScoreButton->setEnabled(true);
   UpdateSegments();
 }
 
@@ -93,7 +95,10 @@ void SelectScoredSegmentDialog::InnerOnProjectProgressUpdate(float progress) {
 
 void SelectScoredSegmentDialog::on_updateScoreButton_clicked()
 {
+  project->needs_rescore_ = true;
+
   ui->progressBar->setEnabled(true);
+  ui->updateScoreButton->setEnabled(false);
   QtConcurrent::run(project, &Project::UpdateScores, (ProjectProgressHandler*)this);
 }