6
0

Project.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. from json import load
  2. from os import path, mkdir, listdir
  3. from os.path import splitext
  4. from uuid import uuid1
  5. from eventlet import spawn_after
  6. from pycs.observable import ObservableDict
  7. from pycs.pipeline.PipelineManager import PipelineManager
  8. from pycs.projects.ImageFile import ImageFile
  9. from pycs.projects.UnmanagedImageFile import UnmanagedImageFile
  10. from pycs.projects.UnmanagedVideoFile import UnmanagedVideoFile
  11. from pycs.projects.VideoFile import VideoFile
  12. from pycs.util.RecursiveDictionary import set_recursive
  13. class Project(ObservableDict):
  14. DEFAULT_PIPELINE_TIMEOUT = 120
  15. def __init__(self, obj: dict, parent):
  16. self.pipeline_manager = None
  17. self.quit_pipeline_thread = None
  18. self.unmanaged_files_keys = []
  19. self.unmanaged_files = {}
  20. # ensure all required object keys are available
  21. for key in ['data', 'labels', 'jobs']:
  22. if key not in obj.keys():
  23. obj[key] = {}
  24. # load model data
  25. folder = path.join('projects', obj['id'], 'model')
  26. with open(path.join(folder, 'distribution.json'), 'r') as file:
  27. model = load(file)
  28. model['path'] = folder
  29. obj['model'] = model
  30. # save data as MediaFile objects
  31. if obj['unmanaged'] is None:
  32. for key in obj['data'].keys():
  33. obj['data'][key] = self.create_media_file(obj['data'][key])
  34. # handle unmanaged files
  35. else:
  36. prev = None
  37. for file in listdir(obj['unmanaged']):
  38. uuid, ext = splitext(file)
  39. next = {
  40. 'id': uuid,
  41. 'extension': ext
  42. }
  43. next = self.create_media_file(next, unmanaged=True)
  44. if prev is not None:
  45. next.prev(prev)
  46. prev.next(next)
  47. prev = next
  48. self.unmanaged_files_keys.append(uuid)
  49. self.unmanaged_files[uuid] = next
  50. # initialize super
  51. super().__init__(obj, parent)
  52. # create data and temp
  53. data_path = path.join('projects', self['id'], 'data')
  54. if not path.exists(data_path):
  55. mkdir(data_path)
  56. temp_path = path.join('projects', self['id'], 'temp')
  57. if not path.exists(temp_path):
  58. mkdir(temp_path)
  59. # subscribe to changes to write to disk afterwards
  60. self.subscribe(lambda d, k: self.parent.write_project(self['id']))
  61. def update_properties(self, update):
  62. set_recursive(update, self)
  63. def get_media_file(self, identifier):
  64. if self['unmanaged']:
  65. if identifier not in self.unmanaged_files_keys:
  66. return None
  67. return self.unmanaged_files[identifier]
  68. else:
  69. if identifier not in self['data'].keys():
  70. return None
  71. return self['data'][identifier]
  72. def new_media_file_path(self):
  73. return path.join('projects', self['id'], 'data'), str(uuid1())
  74. def create_media_file(self, file, unmanaged=False):
  75. # TODO check file extension
  76. # TODO determine type
  77. # TODO filter supported types
  78. if file['extension'] in ['.jpg', '.png']:
  79. if unmanaged:
  80. return UnmanagedImageFile(file, self)
  81. else:
  82. return ImageFile(file, self)
  83. if file['extension'] in ['.mp4']:
  84. if unmanaged:
  85. return UnmanagedVideoFile(file, self)
  86. else:
  87. return VideoFile(file, self)
  88. raise NotImplementedError
  89. def add_media_file(self, uuid, name, extension, size, created):
  90. file = {
  91. 'id': uuid,
  92. 'name': name,
  93. 'extension': extension,
  94. 'size': size,
  95. 'created': created
  96. }
  97. self['data'][file['id']] = self.create_media_file(file)
  98. def remove_media_file(self, file_id):
  99. del self['data'][file_id]
  100. def add_label(self, name):
  101. label_uuid = str(uuid1())
  102. self['labels'][label_uuid] = {
  103. 'id': label_uuid,
  104. 'name': name
  105. }
  106. def update_label(self, identifier, name):
  107. if identifier in self['labels']:
  108. self['labels'][identifier]['name'] = name
  109. def remove_label(self, identifier):
  110. # abort if identifier is unknown
  111. if identifier not in self['labels']:
  112. return
  113. # remove label from data elements
  114. remove = list()
  115. for data in self['data']:
  116. for pred in self['data'][data]['predictionResults']:
  117. if 'label' in self['data'][data]['predictionResults'][pred]:
  118. if self['data'][data]['predictionResults'][pred]['label'] == identifier:
  119. remove.append((data, pred))
  120. for t in remove:
  121. del self['data'][t[0]]['predictionResults'][t[1]]
  122. # remove label from list
  123. del self['labels'][identifier]
  124. def predict(self, identifiers):
  125. # create pipeline
  126. pipeline = self.__create_pipeline()
  127. # run jobs
  128. for file_id in identifiers:
  129. if file_id in self['data'].keys():
  130. pipeline.run(self['data'][file_id])
  131. # schedule timeout thread
  132. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  133. def fit(self):
  134. # create pipeline
  135. pipeline = self.__create_pipeline()
  136. # run fit
  137. pipeline.fit()
  138. # schedule timeout thread
  139. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  140. def __create_pipeline(self):
  141. # abort pipeline termination
  142. self.__quit_pipeline_thread()
  143. # create pipeline if it does not exist already
  144. if self.pipeline_manager is None:
  145. self.pipeline_manager = PipelineManager(self)
  146. return self.pipeline_manager
  147. def __quit_pipeline(self):
  148. if self.pipeline_manager is not None:
  149. self.pipeline_manager.close()
  150. self.pipeline_manager = None
  151. self.quit_pipeline_thread = None
  152. def __create_quit_pipeline_thread(self):
  153. # abort pipeline termination
  154. self.__quit_pipeline_thread()
  155. # create new thread
  156. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  157. def __quit_pipeline_thread(self):
  158. if self.quit_pipeline_thread is not None:
  159. self.quit_pipeline_thread.cancel()
  160. self.quit_pipeline_thread = None