"""Project model: observable project state (media files, labels, jobs) backed by disk."""
  1. from json import load
  2. from os import path, mkdir, listdir
  3. from os.path import splitext
  4. from uuid import uuid1
  5. from eventlet import spawn_after
  6. from pycs.observable import ObservableDict
  7. from pycs.pipeline.PipelineManager import PipelineManager
  8. from pycs.projects.ImageFile import ImageFile
  9. from pycs.projects.VideoFile import VideoFile
  10. from pycs.util.RecursiveDictionary import set_recursive
  11. class Project(ObservableDict):
  12. DEFAULT_PIPELINE_TIMEOUT = 120
  13. def __init__(self, obj: dict, parent):
  14. self.pipeline_manager = None
  15. self.quit_pipeline_thread = None
  16. # ensure all required object keys are available
  17. for key in ['data', 'labels', 'jobs']:
  18. if key not in obj.keys():
  19. obj[key] = {}
  20. # load model data
  21. folder = path.join('projects', obj['id'], 'model')
  22. with open(path.join(folder, 'distribution.json'), 'r') as file:
  23. model = load(file)
  24. model['path'] = folder
  25. obj['model'] = model
  26. # handle unmanaged files
  27. if obj['unmanaged'] is not None:
  28. for file in listdir(obj['unmanaged']):
  29. if file not in obj['data'].keys():
  30. name, ext = splitext(file)
  31. uuid = name
  32. obj['data'][uuid] = self.create_media_file_dict(uuid, name, ext, 0, 0)
  33. # save data as MediaFile objects
  34. for key in obj['data'].keys():
  35. obj['data'][key] = self.create_media_file(obj['data'][key], self)
  36. # initialize super
  37. super().__init__(obj, parent)
  38. # create data and temp
  39. data_path = path.join('projects', self['id'], 'data')
  40. if not path.exists(data_path):
  41. mkdir(data_path)
  42. temp_path = path.join('projects', self['id'], 'temp')
  43. if not path.exists(temp_path):
  44. mkdir(temp_path)
  45. # subscribe to changes to write to disk afterwards
  46. self.subscribe(lambda d, k: self.parent.write_project(self['id']))
  47. def update_properties(self, update):
  48. set_recursive(update, self)
  49. def new_media_file_path(self):
  50. return path.join('projects', self['id'], 'data'), str(uuid1())
  51. @staticmethod
  52. def create_media_file_dict(uuid, name, extension, size, created):
  53. return {
  54. 'id': uuid,
  55. 'name': name,
  56. 'extension': extension,
  57. 'size': size,
  58. 'created': created
  59. }
  60. def create_media_file(self, file, project=None):
  61. if project is None:
  62. project = self
  63. # TODO check file extension
  64. # TODO determine type
  65. # TODO filter supported types
  66. if file['extension'] in ['.jpg', '.png']:
  67. return ImageFile(file, project)
  68. if file['extension'] in ['.mp4']:
  69. return VideoFile(file, project)
  70. raise NotImplementedError
  71. def add_media_file(self, uuid, name, extension, size, created):
  72. file = self.create_media_file_dict(uuid, name, extension, size, created)
  73. self['data'][file['id']] = self.create_media_file(file)
  74. def remove_media_file(self, file_id):
  75. del self['data'][file_id]
  76. def add_label(self, name):
  77. label_uuid = str(uuid1())
  78. self['labels'][label_uuid] = {
  79. 'id': label_uuid,
  80. 'name': name
  81. }
  82. def update_label(self, identifier, name):
  83. if identifier in self['labels']:
  84. self['labels'][identifier]['name'] = name
  85. def remove_label(self, identifier):
  86. # abort if identifier is unknown
  87. if identifier not in self['labels']:
  88. return
  89. # remove label from data elements
  90. remove = list()
  91. for data in self['data']:
  92. for pred in self['data'][data]['predictionResults']:
  93. if 'label' in self['data'][data]['predictionResults'][pred]:
  94. if self['data'][data]['predictionResults'][pred]['label'] == identifier:
  95. remove.append((data, pred))
  96. for t in remove:
  97. del self['data'][t[0]]['predictionResults'][t[1]]
  98. # remove label from list
  99. del self['labels'][identifier]
  100. def predict(self, identifiers):
  101. # create pipeline
  102. pipeline = self.__create_pipeline()
  103. # run jobs
  104. for file_id in identifiers:
  105. if file_id in self['data'].keys():
  106. pipeline.run(self['data'][file_id])
  107. # schedule timeout thread
  108. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  109. def fit(self):
  110. # create pipeline
  111. pipeline = self.__create_pipeline()
  112. # run fit
  113. pipeline.fit()
  114. # schedule timeout thread
  115. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  116. def __create_pipeline(self):
  117. # abort pipeline termination
  118. self.__quit_pipeline_thread()
  119. # create pipeline if it does not exist already
  120. if self.pipeline_manager is None:
  121. self.pipeline_manager = PipelineManager(self)
  122. return self.pipeline_manager
  123. def __quit_pipeline(self):
  124. if self.pipeline_manager is not None:
  125. self.pipeline_manager.close()
  126. self.pipeline_manager = None
  127. self.quit_pipeline_thread = None
  128. def __create_quit_pipeline_thread(self):
  129. # abort pipeline termination
  130. self.__quit_pipeline_thread()
  131. # create new thread
  132. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  133. def __quit_pipeline_thread(self):
  134. if self.quit_pipeline_thread is not None:
  135. self.quit_pipeline_thread.cancel()
  136. self.quit_pipeline_thread = None