6
0

Project.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from json import load
  2. from os import path, mkdir, listdir
  3. from os.path import splitext
  4. from uuid import uuid1
  5. from eventlet import spawn_after
  6. from pycs.observable import ObservableDict
  7. from pycs.pipeline.PipelineManager import PipelineManager
  8. from pycs.projects.ImageFile import ImageFile
  9. from pycs.projects.UnmanagedImageFile import UnmanagedImageFile
  10. from pycs.projects.UnmanagedVideoFile import UnmanagedVideoFile
  11. from pycs.projects.VideoFile import VideoFile
  12. from pycs.util.RecursiveDictionary import set_recursive
  13. class Project(ObservableDict):
  14. DEFAULT_PIPELINE_TIMEOUT = 120
  15. def __init__(self, obj: dict, parent):
  16. self.pipeline_manager = None
  17. self.quit_pipeline_thread = None
  18. self.unmanaged_files_keys = []
  19. self.unmanaged_files = {}
  20. # ensure all required object keys are available
  21. for key in ['data', 'labels', 'jobs']:
  22. if key not in obj.keys():
  23. obj[key] = {}
  24. # load model data
  25. folder = path.join('projects', obj['id'], 'model')
  26. with open(path.join(folder, 'distribution.json'), 'r') as file:
  27. model = load(file)
  28. model['path'] = folder
  29. obj['model'] = model
  30. # save data as MediaFile objects
  31. if obj['unmanaged'] is None:
  32. for key in obj['data'].keys():
  33. obj['data'][key] = self.create_media_file(obj['data'][key])
  34. # handle unmanaged files
  35. else:
  36. prev = None
  37. for file in listdir(obj['unmanaged']):
  38. uuid, ext = splitext(file)
  39. next = {
  40. 'id': uuid,
  41. 'extension': ext
  42. }
  43. next = self.create_media_file(next, unmanaged=True)
  44. if prev is not None:
  45. next.prev(prev)
  46. prev.next(next)
  47. prev = next
  48. self.unmanaged_files_keys.append(uuid)
  49. self.unmanaged_files[uuid] = next
  50. length = len(self.unmanaged_files_keys)
  51. for key in self.unmanaged_files:
  52. self.unmanaged_files[key].length(length)
  53. # initialize super
  54. super().__init__(obj, parent)
  55. # create data and temp
  56. data_path = path.join('projects', self['id'], 'data')
  57. if not path.exists(data_path):
  58. mkdir(data_path)
  59. temp_path = path.join('projects', self['id'], 'temp')
  60. if not path.exists(temp_path):
  61. mkdir(temp_path)
  62. # subscribe to changes to write to disk afterwards
  63. self.subscribe(lambda d, k: self.parent.write_project(self['id']))
  64. def update_properties(self, update):
  65. set_recursive(update, self)
  66. def get_media_file(self, identifier):
  67. if self['unmanaged']:
  68. if identifier not in self.unmanaged_files_keys:
  69. return None
  70. return self.unmanaged_files[identifier]
  71. else:
  72. if identifier not in self['data'].keys():
  73. return None
  74. return self['data'][identifier]
  75. def new_media_file_path(self):
  76. return path.join('projects', self['id'], 'data'), str(uuid1())
  77. def create_media_file(self, file, unmanaged=False):
  78. # TODO check file extension
  79. # TODO determine type
  80. # TODO filter supported types
  81. if file['extension'] in ['.jpg', '.png']:
  82. if unmanaged:
  83. return UnmanagedImageFile(file, self)
  84. else:
  85. return ImageFile(file, self)
  86. if file['extension'] in ['.mp4']:
  87. if unmanaged:
  88. return UnmanagedVideoFile(file, self)
  89. else:
  90. return VideoFile(file, self)
  91. raise NotImplementedError
  92. def add_media_file(self, uuid, name, extension, size, created):
  93. file = {
  94. 'id': uuid,
  95. 'name': name,
  96. 'extension': extension,
  97. 'size': size,
  98. 'created': created
  99. }
  100. self['data'][file['id']] = self.create_media_file(file)
  101. def remove_media_file(self, file_id):
  102. del self['data'][file_id]
  103. def add_label(self, name):
  104. label_uuid = str(uuid1())
  105. self['labels'][label_uuid] = {
  106. 'id': label_uuid,
  107. 'name': name
  108. }
  109. def update_label(self, identifier, name):
  110. if identifier in self['labels']:
  111. self['labels'][identifier]['name'] = name
  112. def remove_label(self, identifier):
  113. # abort if identifier is unknown
  114. if identifier not in self['labels']:
  115. return
  116. # remove label from data elements
  117. remove = list()
  118. for data in self['data']:
  119. for pred in self['data'][data]['predictionResults']:
  120. if 'label' in self['data'][data]['predictionResults'][pred]:
  121. if self['data'][data]['predictionResults'][pred]['label'] == identifier:
  122. remove.append((data, pred))
  123. for t in remove:
  124. del self['data'][t[0]]['predictionResults'][t[1]]
  125. # remove label from list
  126. del self['labels'][identifier]
  127. def predict(self, identifiers):
  128. # create pipeline
  129. pipeline = self.__create_pipeline()
  130. # run jobs
  131. for file_id in identifiers:
  132. if file_id in self['data'].keys():
  133. pipeline.run(self['data'][file_id])
  134. # schedule timeout thread
  135. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  136. def fit(self):
  137. # create pipeline
  138. pipeline = self.__create_pipeline()
  139. # run fit
  140. pipeline.fit()
  141. # schedule timeout thread
  142. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  143. def __create_pipeline(self):
  144. # abort pipeline termination
  145. self.__quit_pipeline_thread()
  146. # create pipeline if it does not exist already
  147. if self.pipeline_manager is None:
  148. self.pipeline_manager = PipelineManager(self)
  149. return self.pipeline_manager
  150. def __quit_pipeline(self):
  151. if self.pipeline_manager is not None:
  152. self.pipeline_manager.close()
  153. self.pipeline_manager = None
  154. self.quit_pipeline_thread = None
  155. def __create_quit_pipeline_thread(self):
  156. # abort pipeline termination
  157. self.__quit_pipeline_thread()
  158. # create new thread
  159. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  160. def __quit_pipeline_thread(self):
  161. if self.quit_pipeline_thread is not None:
  162. self.quit_pipeline_thread.cancel()
  163. self.quit_pipeline_thread = None