|
- import json
- import os
- from PyQt5 import uic, QtWidgets, QtCore, QtGui
- from . import AboutDialog
- from .LabelDialog import LabelDialog
- from .NewProjectWizard import NewProjectWizard
- from ..project import Project
- from ..utils import Video
class MainWindow:
    """Controller for the application's main window.

    The actual Qt widget is a ``QMainWindow`` populated from ``MainWindow.ui``
    and stored in ``self.ui``; this class is *not* itself a widget.  It keeps
    track of the currently opened project and of the list of prediction jobs
    and keeps the widgets in sync with both.
    """

    # Lower-case file extensions (including the dot) accepted for prediction.
    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')
    _VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mkv', '.mpg', '.mpeg')

    ##############
    # Properties #
    ##############

    def _set_project(self, project):
        """
        stores the current project and refreshes the project dependent ui
        :param project: project instance or None if no project is open
        """
        self.__project = project
        if project is not None:
            # Give the project a hook to request a ui refresh.
            project._update_ui_fn = lambda: self._update_ui()
        self._update_ui()

    def _get_project(self):
        """
        :return: the current project or None
        """
        return self.__project

    _project = property(fget=_get_project, fset=_set_project)

    def _has_project(self):
        """
        :return: True if a project is currently open, False otherwise
        """
        return self.__project is not None

    def _set_predictions(self, predictions):
        """
        stores the prediction list and rebuilds the prediction ui
        :param predictions: list of prediction/job dictionaries
        """
        self.__predictions = predictions
        self._update_prediction_ui()

    def _get_predictions(self):
        """
        :return: list of prediction/job dictionaries
        """
        return self.__predictions

    _predictions = property(fget=_get_predictions, fset=_set_predictions)

    def _project_needs_saving(self):
        """
        :return: always True, i.e. every open project is treated as modified
        """
        return True

    ###############
    # Constructor #
    ###############

    def __init__(self, **kwargs):
        """
        creates the window, loads its ui file and connects all actions
        :param kwargs: forwarded unchanged to the QMainWindow constructor
        """
        # initialize window
        self.ui = QtWidgets.QMainWindow(**kwargs)

        # set object properties
        self.__project = None
        self.__predictions = []

        # load ui from the .ui file located next to this module
        spath = os.path.dirname(__file__)
        uic.loadUi(os.path.join(spath, 'MainWindow.ui'), self.ui)

        # prepare some more ui stuff
        self.ui.statusLabel = QtWidgets.QLabel(self.ui)
        self.ui.statusBar.addPermanentWidget(self.ui.statusLabel)

        # NOTE: timer and thread are only created here, neither is started.
        self.ui.timerThread = QtCore.QThread(self.ui)
        self.ui.timer = QtCore.QTimer()
        self.ui.timer.setInterval(2000)
        self.ui.timer.moveToThread(self.ui.timerThread)

        # connect actions
        # File
        self.ui.actionNew.triggered.connect(self._project_new)
        self.ui.actionOpen.triggered.connect(self.project_open)
        self.ui.actionClose.triggered.connect(self._project_try_closing)
        self.ui.actionSave.triggered.connect(self._project_save)
        self.ui.actionQuit.triggered.connect(self._file_quit)

        # Labeling
        self.ui.actionStart_Labeling_Session.triggered.connect(self._start_labeling_session)
        self.ui.startLabelingSessionButton.clicked.connect(self._start_labeling_session)

        # Help
        self.ui.actionAbout.triggered.connect(self._help_about)

        # Prediction
        self.ui.actionPredict_Images.triggered.connect(self.predict_via_dialog)
        self.ui.predictButton.clicked.connect(self.predict_via_dialog)
        self.ui.imageThumbnailGallery.currentRowChanged.connect(self._prediction_row_changed)

        self._update_ui()
        self._update_prediction_ui()

    def _update_ui(self):
        """
        refreshes the status text and enables/disables all widgets and
        actions that only make sense while a project is open
        """
        # Status text
        if self._has_project():
            status_text = "OK"
            if self._project.last_error is not None:
                status_text = self._project.last_error
        else:
            status_text = "No project loaded."
        self.ui.pipelineStatusLabel.setText(status_text)

        # Action availability
        project_only_actions = [
            self.ui.actionSave,
            self.ui.actionSave_As,
            self.ui.actionClose,
            self.ui.actionImport_Labeled_Faces,
            self.ui.actionImport_Labeled_Images,
            self.ui.startLabelingSessionButton,
            self.ui.actionStart_Labeling_Session,
            self.ui.actionManage_Individuals,
            self.ui.actionPredict_Images,
            self.ui.predictButton,
            self.ui.actionUpdate_Model,
            self.ui.updateButton,
            self.ui.actionValidate_Model,
            self.ui.clearPredictionsButton,
            self.ui.exportCSVButton,
        ]
        enabled = self._has_project()
        for action in project_only_actions:
            action.setEnabled(enabled)

    def show(self):
        """
        shows the wrapped QMainWindow
        :return: whatever QMainWindow.show() returns
        """
        return self.ui.show()

    #####################
    # Project lifecycle #
    #####################

    def _project_try_closing(self):
        """
        closes the current project, asking the user whether to save first
        :return: True if no project is open afterwards, False if the user
                 cancelled or saving failed (the project then stays open)
        """
        if not self._has_project():
            return True

        if self._project_needs_saving():
            # Ask user
            box = QtWidgets.QMessageBox()
            box.setText('The current project has unsaved changes. Would you like to save them?')
            box.setStandardButtons(
                QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No | QtWidgets.QMessageBox.Cancel)
            retval = box.exec_()

            if retval == QtWidgets.QMessageBox.Yes:
                self._project.save()
                if self._project.last_error is not None:
                    error_string = self._project.last_error
                    QtWidgets.QMessageBox.warning(
                        self.ui, 'Error',
                        'Error while saving (see log for details): %s' % error_string)
                    self._project.reset_error()
                    # Keep the project open so that no changes are lost.
                    return False
            elif retval == QtWidgets.QMessageBox.Cancel:
                return False
            # QMessageBox.No: discard the changes and fall through to close.

        self._project.close()
        self._project = None
        return True

    def _project_new(self):
        """
        closes the current project (if any) and runs the new project wizard
        """
        if not self._project_try_closing():
            return
        wizard = NewProjectWizard(self.ui)
        if wizard.exec_():
            self._project = Project(wizard.project_root, wizard.project_config)

    def project_open(self, path=False):
        """
        opens a project by either using the optional parameter path or
        the value the user enters using the ui file dialog
        :param path: path to project.json; the default False (which is also
                     what the ``triggered`` signal passes as its ``checked``
                     argument) means "ask the user"
        """
        # check if there is an open project with unsaved changes
        if not self._project_try_closing():
            return

        # copy path to selection or question user if not given
        if path is not False:
            selection = path
        else:
            selection, _ = QtWidgets.QFileDialog.getOpenFileName(
                self.ui, 'Select Project Folder', filter='project.json')

        # nothing selected (dialog cancelled)
        if len(selection) == 0:
            return

        # Project file does not exist
        if not os.path.exists(selection):
            QtWidgets.QMessageBox.warning(
                self.ui, 'Error', 'Could not find project file: %s' % selection)
            return

        self._project = Project(selection)
        if self._project.last_error is not None:
            error_string = self._project.last_error
            QtWidgets.QMessageBox.warning(
                self.ui, 'Error',
                'Error while opening (see log for details): %s' % error_string)
            try:
                self._project.close()
            except Exception:
                # Best effort only: the project failed to load, so closing it
                # may fail as well; it is dropped below in any case.
                pass
            self._project = None

    def _project_save(self):
        """
        saves the current project and reports a save error to the user
        """
        if not self._has_project():
            # Nothing to save; the action is normally disabled in this state.
            return
        self._project.save()
        if self._project.last_error is not None:
            error_string = self._project.last_error
            QtWidgets.QMessageBox.warning(
                self.ui, 'Error',
                'Error while saving (see log for details): %s' % error_string)
            self._project.reset_error()

    def _file_quit(self):
        """
        closes the main window, unless closing the current project is
        cancelled by the user or fails
        """
        # Goes through the regular closing dialog so that unsaved changes
        # are not dropped silently.
        if self._project_try_closing():
            # The window is self.ui; this class has no close() of its own.
            self.ui.close()

    def _start_labeling_session(self):
        """
        opens the modal labeling dialog for the current predictions
        """
        labeling = LabelDialog(self._predictions)
        labeling.exec_()

    def _help_about(self):
        """
        opens the modal about dialog
        """
        # NOTE(review): AboutDialog comes from ``from . import AboutDialog``;
        # confirm that this name is the dialog class and not a submodule.
        about = AboutDialog()
        about.exec_()

    #####################################
    # Prediction actions and management #
    #####################################

    @staticmethod
    def _build_prediction_jobs(selection):
        """
        creates one face detection job per supported file
        :param selection: iterable of file paths
        :return: list of job dictionaries with the keys 'filename',
                 'filetype' ('image' or 'video') and 'jobs'; files with an
                 unsupported extension are skipped
        """
        jobs = []
        for filename in selection:
            extension = os.path.splitext(filename)[1].lower()
            if extension in MainWindow._IMAGE_EXTENSIONS:
                filetype = 'image'
            elif extension in MainWindow._VIDEO_EXTENSIONS:
                filetype = 'video'
            else:
                continue
            jobs.append({
                'filename': filename,
                'filetype': filetype,
                'jobs': ['detect-faces']
            })
        return jobs

    def predict_via_dialog(self, selection=False):
        """
        loads images/videos from given paths or ui file dialog, runs the
        prediction jobs on the current project and appends them to the
        prediction list
        :param selection: array of file paths; the default False (which is
                          also what the ``triggered``/``clicked`` signals
                          pass as ``checked``) means "ask the user"
        :return:
        """
        if not self._has_project():
            # Jobs are executed by the project, so there is nothing to do.
            return

        # ask user for input
        if selection is False:
            name_list = ' '.join(
                '*' + extension
                for extension in self._IMAGE_EXTENSIONS + self._VIDEO_EXTENSIONS)
            extensions = f'All supported files ({name_list})'
            selection, _ = QtWidgets.QFileDialog.getOpenFileNames(
                self.ui, 'Select Files To Add', filter=extensions)

        # load jobs
        jobs = self._build_prediction_jobs(selection)
        if not jobs:
            # Dialog cancelled or no supported file selected.
            return

        progress_dialog = QtWidgets.QProgressDialog(
            'Executing prediction jobs...', 'Abort', 0, 1000, self.ui)
        progress_dialog.setWindowModality(QtCore.Qt.WindowModal)

        def callback(progress):
            # progress is scaled by 1000 to match the dialog range 0..1000
            progress_dialog.setValue(int(progress * 1000))
            # keep the ui responsive while the jobs are running
            for _ in range(10):
                QtWidgets.QApplication.processEvents()

        try:
            self._project.execute(jobs, callback=callback)
        finally:
            # never leave the modal dialog open, even if execution fails
            progress_dialog.close()

        self._predictions += jobs

    def _update_prediction_ui(self):
        """
        rebuilds the thumbnail gallery from the prediction list and refreshes
        the detail view for the current row
        """
        # TODO implement better sync
        gallery = self.ui.imageThumbnailGallery
        gallery.clear()
        for prediction in self._predictions:
            filename = prediction['filename']
            label = os.path.basename(filename)
            if prediction['filetype'] == 'image':
                item = QtWidgets.QListWidgetItem(QtGui.QIcon(filename), label)
            else:
                item = QtWidgets.QListWidgetItem(label)
            gallery.addItem(item)

        # With an empty prediction list every row is out of range, so this
        # single call also resets the detail view.
        self._prediction_row_changed(gallery.currentRow())

    def _prediction_row_changed(self, row):
        """
        shows the prediction of the given gallery row
        :param row: row index; an index outside of the prediction list
                    (e.g. -1) clears the detail view
        """
        if 0 <= row < len(self._predictions):
            prediction = self._predictions[row]
        else:
            prediction = {'filetype': 'empty'}
        self._setup_video_slider(prediction)
        self._display_prediction(prediction)

    def _disconnect_frame_slider(self):
        """
        removes all handlers from the frame slider's valueChanged signal
        """
        try:
            self.ui.frameSlider.valueChanged.disconnect()
        except (TypeError, RuntimeError):
            # Raised when there is nothing to disconnect; that is fine.
            pass

    def _setup_video_slider(self, prediction):
        """
        shows and configures the frame slider for video predictions and
        hides it for everything else
        :param prediction: prediction dictionary, needs the key 'filetype'
                           and for videos also 'filename'
        """
        if prediction['filetype'] != 'video':
            self._disable_video_slider()
            return

        video_cap = Video(prediction['filename'])
        if video_cap.last_error is not None:
            self._disable_video_slider()
            return

        # Drop the handler of a previously selected video first, otherwise
        # handlers pile up and the reset below would trigger the old one.
        self._disconnect_frame_slider()

        self.ui.frameNumber.setDecMode()
        self.ui.frameNumber.display(0)
        self.ui.frameNumber.setDigitCount(5)
        self.ui.frameNumber.show()

        self.ui.frameSlider.show()
        self.ui.frameSlider.setValue(0)
        self.ui.frameSlider.setMinimum(0)
        self.ui.frameSlider.setMaximum(video_cap.framecount - 1)

        def frame_changed(value):
            # Show the selected frame of this video and its frame number.
            self._display_prediction({**prediction, 'frame': value, 'cap': video_cap})
            self.ui.frameNumber.display(value)

        self.ui.frameSlider.valueChanged.connect(frame_changed)

    def _disable_video_slider(self):
        """
        resets and hides the frame slider and the frame number display
        """
        # Disconnect before resetting so that the reset does not trigger the
        # handler of the previously shown video.
        self._disconnect_frame_slider()
        self.ui.frameSlider.setValue(0)
        self.ui.frameSlider.hide()
        self.ui.frameNumber.hide()

    def _display_prediction(self, prediction):
        """
        shows the prediction in the image view and as json in the debug text
        :param prediction: prediction dictionary; the value of an optional
                           'cap' entry is replaced by a placeholder in the
                           debug text
        """
        self.ui.annotatedImageView.display(prediction)

        filtered_dict = {key: value for key, value in prediction.items() if key != 'cap'}
        if 'cap' in prediction:
            filtered_dict['cap'] = '[VideoCapture]'
        self.ui.predictionDebugText.setPlainText(
            json.dumps(filtered_dict, indent=2, skipkeys=True))
|