Project.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from json import load
  2. from os import path
  3. from uuid import uuid1
  4. from eventlet import spawn_after
  5. from pycs.observable import ObservableDict
  6. from pycs.pipeline.PipelineManager import PipelineManager
  7. from pycs.projects.MediaFile import MediaFile
  8. from pycs.util.RecursiveDictionary import set_recursive
  9. class Project(ObservableDict):
  10. DEFAULT_PIPELINE_TIMEOUT = 120
  11. def __init__(self, obj: dict, parent):
  12. self.pipeline_manager = None
  13. self.quit_pipeline_thread = None
  14. # ensure all required object keys are available
  15. for key in ['data', 'labels', 'jobs']:
  16. if key not in obj.keys():
  17. obj[key] = {}
  18. # load model data
  19. folder = path.join('projects', obj['id'], 'model')
  20. with open(path.join(folder, 'distribution.json'), 'r') as file:
  21. model = load(file)
  22. model['path'] = folder
  23. obj['model'] = model
  24. # save data as MediaFile objects
  25. for key in obj['data'].keys():
  26. obj['data'][key] = MediaFile(obj['data'][key], obj['id'])
  27. # initialize super
  28. super().__init__(obj, parent)
  29. # subscribe to changes to write to disk afterwards
  30. self.subscribe(lambda d, k: self.parent.write_project(self['id']))
  31. def update_properties(self, update):
  32. set_recursive(update, self)
  33. def new_media_file_path(self):
  34. return path.join('projects', self['id'], 'data'), str(uuid1())
  35. def add_media_file(self, file):
  36. file = MediaFile(file, self['id'])
  37. self['data'][file['id']] = file
  38. def remove_media_file(self, file_id):
  39. del self['data'][file_id]
  40. def add_label(self, name):
  41. label_uuid = str(uuid1())
  42. self['labels'][label_uuid] = {
  43. 'id': label_uuid,
  44. 'name': name
  45. }
  46. def update_label(self, identifier, name):
  47. if identifier in self['labels']:
  48. self['labels'][identifier]['name'] = name
  49. def remove_label(self, identifier):
  50. # abort if identifier is unknown
  51. if identifier not in self['labels']:
  52. return
  53. # remove label from data elements
  54. remove = list()
  55. for data in self['data']:
  56. for pred in self['data'][data]['predictionResults']:
  57. if 'label' in self['data'][data]['predictionResults'][pred]:
  58. if self['data'][data]['predictionResults'][pred]['label'] == identifier:
  59. remove.append((data, pred))
  60. for t in remove:
  61. del self['data'][t[0]]['predictionResults'][t[1]]
  62. # remove label from list
  63. del self['labels'][identifier]
  64. def predict(self, identifiers):
  65. # create pipeline
  66. pipeline = self.__create_pipeline()
  67. # run jobs
  68. for file_id in identifiers:
  69. if file_id in self['data'].keys():
  70. pipeline.run(self['data'][file_id])
  71. # schedule timeout thread
  72. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  73. def fit(self):
  74. # create pipeline
  75. pipeline = self.__create_pipeline()
  76. # run fit
  77. pipeline.fit()
  78. # schedule timeout thread
  79. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  80. def __create_pipeline(self):
  81. # abort pipeline termination
  82. self.__quit_pipeline_thread()
  83. # create pipeline if it does not exist already
  84. if self.pipeline_manager is None:
  85. self.pipeline_manager = PipelineManager(self)
  86. return self.pipeline_manager
  87. def __quit_pipeline(self):
  88. if self.pipeline_manager is not None:
  89. self.pipeline_manager.close()
  90. self.pipeline_manager = None
  91. self.quit_pipeline_thread = None
  92. def __create_quit_pipeline_thread(self):
  93. # abort pipeline termination
  94. self.__quit_pipeline_thread()
  95. # create new thread
  96. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  97. def __quit_pipeline_thread(self):
  98. if self.quit_pipeline_thread is not None:
  99. self.quit_pipeline_thread.cancel()
  100. self.quit_pipeline_thread = None