# MainWindow.py (13 KB, 340 lines)
import copy
import json
import logging
import os

from PyQt5 import uic, QtWidgets, QtCore, QtGui

from . import AboutDialog
from .NewProjectWizard import NewProjectWizard
from ..project import Project
from ..utils import Video


  9. class MainWindow:
  10. # Properties
  11. def _set_project(self, project):
  12. self.__project = project
  13. if project is not None:
  14. project._update_ui_fn = lambda: self._update_ui()
  15. self._update_ui()
  16. def _get_project(self):
  17. return self.__project
  18. _project = property(fget=_get_project, fset=_set_project)
  19. def _has_project(self):
  20. if self.__project is not None:
  21. return True
  22. else:
  23. return False
  24. def _set_predictions(self, predictions):
  25. self.__predictions = predictions
  26. self._update_prediction_ui()
  27. def _get_predictions(self):
  28. return self.__predictions
  29. _predictions = property(fget=_get_predictions, fset=_set_predictions)
  30. def _project_needs_saving(self):
  31. return True
  32. # Constructor
  33. def __init__(self, **kwargs):
  34. # initialize window
  35. self.ui = QtWidgets.QMainWindow(**kwargs)
  36. # set object properties
  37. self.__project = None
  38. self.__predictions = []
  39. # load ui from corresponding file
  40. spath = os.path.dirname(__file__)
  41. uic.loadUi(os.path.join(spath, 'MainWindow.ui'), self.ui)
  42. # prepare some more ui stuff
  43. self.ui.statusLabel = QtWidgets.QLabel(self.ui)
  44. self.ui.statusBar.addPermanentWidget(self.ui.statusLabel)
  45. self.ui.timerThread = QtCore.QThread(self.ui)
  46. self.ui.timer = QtCore.QTimer()
  47. self.ui.timer.setInterval(2000)
  48. self.ui.timer.moveToThread(self.ui.timerThread)
  49. # connect actions
  50. # File
  51. self.ui.actionNew.triggered.connect(self._project_new)
  52. self.ui.actionOpen.triggered.connect(self.project_open)
  53. self.ui.actionClose.triggered.connect(self._project_try_closing)
  54. self.ui.actionSave.triggered.connect(self._project_save)
  55. self.ui.actionQuit.triggered.connect(self._file_quit)
  56. # Help
  57. self.ui.actionAbout.triggered.connect(self._help_about)
  58. self.ui.actionPredict_Images.triggered.connect(self._predict_via_dialog)
  59. self.ui.predictButton.clicked.connect(self._predict_via_dialog)
  60. self.ui.imageThumbnailGallery.currentRowChanged.connect(self._prediction_row_changed)
  61. self._update_ui()
  62. self._update_prediction_ui()
  63. # TODO delete me
  64. # self._project = Project('test-project')
  65. def _update_ui(self):
  66. # Status text
  67. if self._has_project():
  68. status_text = "OK"
  69. if self._project.last_error is not None:
  70. status_text = self._project.last_error
  71. else:
  72. status_text = "No project loaded."
  73. self.ui.pipelineStatusLabel.setText(status_text)
  74. # Action availability
  75. project_only_actions = [self.ui.actionSave,
  76. self.ui.actionSave_As,
  77. self.ui.actionClose,
  78. self.ui.actionImport_Labeled_Faces,
  79. self.ui.actionImport_Labeled_Images,
  80. self.ui.startLabelingSessionButton,
  81. self.ui.actionStart_Labeling_Session,
  82. self.ui.actionManage_Individuals,
  83. self.ui.actionPredict_Images,
  84. self.ui.predictButton,
  85. self.ui.actionUpdate_Model,
  86. self.ui.updateButton,
  87. self.ui.actionValidate_Model,
  88. self.ui.clearPredictionsButton,
  89. self.ui.exportCSVButton]
  90. for action in project_only_actions:
  91. action.setEnabled(self._has_project())
  92. def show(self):
  93. return self.ui.show()
  94. #####################
  95. # Project lifecycle #
  96. #####################
  97. def _project_try_closing(self):
  98. if self._has_project():
  99. if self._project_needs_saving():
  100. # Ask user
  101. box = QtWidgets.QMessageBox()
  102. box.setText('The current project has unsaved changes. Would you like to save them?')
  103. box.setStandardButtons(QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No | QtWidgets.QMessageBox.Cancel)
  104. retval = box.exec_()
  105. if retval == QtWidgets.QMessageBox.Yes:
  106. self._project.save()
  107. if self._project.last_error is not None:
  108. error_string = self._project.last_error
  109. QtWidgets.QMessageBox.warning(self.ui, 'Error', 'Error while saving (see log for details): %s'
  110. % error_string)
  111. self._project.reset_error()
  112. return False
  113. elif retval == QtWidgets.QMessageBox.No:
  114. pass
  115. elif retval == QtWidgets.QMessageBox.Cancel:
  116. return False
  117. self._project.close()
  118. self._project = None
  119. return True
  120. else:
  121. return True
  122. def _project_new(self):
  123. retval = self._project_try_closing()
  124. if retval:
  125. w = NewProjectWizard(self.ui)
  126. w_retval = w.exec_()
  127. if w_retval:
  128. self._project = Project(w.project_root, w.project_config)
  129. def project_open(self, path=None):
  130. """
  131. opens a project by either using the optional parameter path or
  132. or the value the user enters using the ui file dialog
  133. :param path: path to project.json
  134. """
  135. # check if there is an open project with unsaved changes
  136. retval = self._project_try_closing()
  137. if not retval:
  138. return
  139. # copy path to selection or question user if not given
  140. if path is not None and path != False:
  141. selection = path
  142. else:
  143. selection, _ = QtWidgets.QFileDialog.getOpenFileName(self.ui, 'Select Project Folder', filter='project.json')
  144. # check if path exists and open project
  145. if len(selection) > 0 and os.path.exists(selection):
  146. self._project = Project(selection)
  147. if self._project.last_error is not None:
  148. error_string = self._project.last_error
  149. try:
  150. QtWidgets.QMessageBox.warning(self.ui, 'Error', 'Error while opening (see log for details): %s'
  151. % error_string)
  152. self._project.close()
  153. except:
  154. pass
  155. self._project = None
  156. elif len(selection) > 0: # Project file does not exist
  157. QtWidgets.QMessageBox.warning(self.ui, 'Error', 'Could not find project file: %s' % selection)
  158. def _project_save(self):
  159. self._project.save()
  160. if self._project.last_error is not None:
  161. error_string = self._project.last_error
  162. QtWidgets.QMessageBox.warning(self.ui, 'Error', 'Error while saving (see log for details): %s'
  163. % error_string)
  164. self._project.reset_error()
  165. def _file_quit(self):
  166. # TODO: warning if there are any unsaved changes
  167. self.close()
  168. def _help_about(self):
  169. about = AboutDialog()
  170. about.exec_()
  171. #####################################
  172. # Prediction actions and management #
  173. #####################################
  174. def _predict_via_dialog(self):
  175. # '''
  176. # {'filename': filename, 'filetype': filetype, 'jobs': ['detect-faces']}
  177. image_extensions = ['.jpg', '.png', '.jpeg']
  178. video_extensions = ['.mp4', '.avi', '.mkv', '.mpg', '.mpeg']
  179. name_list = ' '.join(['*' + extension for extension in image_extensions + video_extensions])
  180. extensions = f'All supported files ({name_list})'
  181. jobs = []
  182. selection = QtWidgets.QFileDialog.getOpenFileNames(self.ui, 'Select Files To Add', filter=extensions)
  183. for filename in selection[0]:
  184. _, extension = os.path.splitext(filename)
  185. extension = extension.lower()
  186. if extension in image_extensions:
  187. jobs.append({
  188. 'filename': filename,
  189. 'filetype': 'image',
  190. 'jobs': ['detect-faces']
  191. })
  192. elif extension in video_extensions:
  193. jobs.append({
  194. 'filename': filename,
  195. 'filetype': 'image',
  196. 'jobs': ['detect-faces']
  197. })
  198. progress_dialog = QtWidgets.QProgressDialog('Executing prediction jobs...', 'Abort', 0, 1000, self.ui)
  199. progress_dialog.setWindowModality(QtCore.Qt.WindowModal)
  200. def callback(progress):
  201. progress_dialog.setValue(int(progress * 1000))
  202. for i in range(10):
  203. QtWidgets.QApplication.processEvents()
  204. self._project.execute(jobs, callback=callback)
  205. progress_dialog.close()
  206. self._predictions += jobs
  207. '''
  208. TODO: remove PredictionDialog.py
  209. prediction_dialog = PredictionDialog(parent=self.ui)
  210. if prediction_dialog.exec_():
  211. # TODO extract callback and progress dialog to make it reusable
  212. jobs = copy.deepcopy(prediction_dialog.jobs)
  213. print(jobs)
  214. progress_dialog = QtWidgets.QProgressDialog('Executing prediction jobs...', 'Abort', 0, 1000, self.ui)
  215. progress_dialog.setWindowModality(QtCore.Qt.WindowModal)
  216. def callback(progress):
  217. progress_dialog.setValue(int(progress * 1000))
  218. for i in range(10):
  219. QtWidgets.QApplication.processEvents()
  220. self._project.execute(jobs, callback=callback)
  221. progress_dialog.close()
  222. self._predictions += jobs
  223. '''
  224. def _update_prediction_ui(self):
  225. # TODO implement better sync
  226. self.ui.imageThumbnailGallery.clear()
  227. for prediction in self._predictions:
  228. filename = prediction['filename']
  229. if prediction['filetype'] == 'image':
  230. icon = QtGui.QIcon(filename)
  231. item = QtWidgets.QListWidgetItem(icon, os.path.basename(filename))
  232. else:
  233. item = QtWidgets.QListWidgetItem(os.path.basename(filename))
  234. self.ui.imageThumbnailGallery.addItem(item)
  235. self._prediction_row_changed(self.ui.imageThumbnailGallery.currentRow())
  236. if len(self._predictions) == 0:
  237. self._prediction_row_changed(-1)
  238. def _prediction_row_changed(self, row):
  239. if row in range(len(self._predictions)):
  240. prediction = self._predictions[row]
  241. self._setup_video_slider(prediction)
  242. self._display_prediction(prediction)
  243. else:
  244. self._setup_video_slider({'filetype': 'empty'})
  245. self._display_prediction({'filetype': 'empty'})
  246. def _setup_video_slider(self, prediction):
  247. if prediction['filetype'] == 'video':
  248. video_cap = Video(prediction['filename'])
  249. if video_cap.last_error is None:
  250. self.ui.frameNumber.setDecMode()
  251. self.ui.frameNumber.display(0)
  252. self.ui.frameNumber.setDigitCount(5)
  253. self.ui.frameNumber.show()
  254. self.ui.frameSlider.show()
  255. self.ui.frameSlider.setValue(0)
  256. self.ui.frameSlider.setMinimum(0)
  257. self.ui.frameSlider.setMaximum(video_cap.framecount - 1)
  258. self.ui.frameSlider.valueChanged.connect(
  259. lambda value, _prediction=prediction:
  260. [self._display_prediction(dict(frame=value, cap=video_cap, **_prediction)),
  261. self.ui.frameNumber.display(value)])
  262. else:
  263. self._disable_video_slider()
  264. else:
  265. self._disable_video_slider()
  266. def _disable_video_slider(self):
  267. self.ui.frameSlider.setValue(0)
  268. try:
  269. self.ui.frameSlider.valueChanged.disconnect()
  270. except:
  271. pass
  272. self.ui.frameSlider.hide()
  273. self.ui.frameNumber.hide()
  274. def _display_prediction(self, prediction):
  275. self.ui.annotatedImageView.display(prediction)
  276. filtered_dict = {i: prediction[i] for i in prediction if i != 'cap'}
  277. if 'cap' in prediction.keys():
  278. filtered_dict['cap'] = '[VideoCapture]'
  279. self.ui.predictionDebugText.setPlainText(json.dumps(filtered_dict, indent=2, skipkeys=True))