Project.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from json import load
  2. from os import path
  3. from uuid import uuid1
  4. from eventlet import spawn_after
  5. from pycs.observable import ObservableDict
  6. from pycs.pipeline.PipelineManager import PipelineManager
  7. from pycs.projects.ImageFile import ImageFile
  8. from pycs.projects.VideoFile import VideoFile
  9. from pycs.util.RecursiveDictionary import set_recursive
  10. class Project(ObservableDict):
  11. DEFAULT_PIPELINE_TIMEOUT = 120
  12. def __init__(self, obj: dict, parent):
  13. self.pipeline_manager = None
  14. self.quit_pipeline_thread = None
  15. # ensure all required object keys are available
  16. for key in ['data', 'labels', 'jobs']:
  17. if key not in obj.keys():
  18. obj[key] = {}
  19. # load model data
  20. folder = path.join('projects', obj['id'], 'model')
  21. with open(path.join(folder, 'distribution.json'), 'r') as file:
  22. model = load(file)
  23. model['path'] = folder
  24. obj['model'] = model
  25. # save data as MediaFile objects
  26. for key in obj['data'].keys():
  27. obj['data'][key] = self.create_media_file(obj['data'][key], obj['id'])
  28. # initialize super
  29. super().__init__(obj, parent)
  30. # subscribe to changes to write to disk afterwards
  31. self.subscribe(lambda d, k: self.parent.write_project(self['id']))
  32. def update_properties(self, update):
  33. set_recursive(update, self)
  34. def new_media_file_path(self):
  35. return path.join('projects', self['id'], 'data'), str(uuid1())
  36. def create_media_file(self, file, project_id=None):
  37. if project_id is None:
  38. project_id = self['id']
  39. # TODO check file extension
  40. # TODO determine type
  41. # TODO filter supported types
  42. if file['extension'] in ['.jpg', '.png']:
  43. return ImageFile(file, project_id)
  44. if file['extension'] in ['.mp4']:
  45. return VideoFile(file, project_id)
  46. raise NotImplementedError
  47. def add_media_file(self, uuid, name, extension, size, created):
  48. file = {
  49. 'id': uuid,
  50. 'name': name,
  51. 'extension': extension,
  52. 'size': size,
  53. 'created': created
  54. }
  55. self['data'][file['id']] = self.create_media_file(file)
  56. def remove_media_file(self, file_id):
  57. del self['data'][file_id]
  58. def add_label(self, name):
  59. label_uuid = str(uuid1())
  60. self['labels'][label_uuid] = {
  61. 'id': label_uuid,
  62. 'name': name
  63. }
  64. def update_label(self, identifier, name):
  65. if identifier in self['labels']:
  66. self['labels'][identifier]['name'] = name
  67. def remove_label(self, identifier):
  68. # abort if identifier is unknown
  69. if identifier not in self['labels']:
  70. return
  71. # remove label from data elements
  72. remove = list()
  73. for data in self['data']:
  74. for pred in self['data'][data]['predictionResults']:
  75. if 'label' in self['data'][data]['predictionResults'][pred]:
  76. if self['data'][data]['predictionResults'][pred]['label'] == identifier:
  77. remove.append((data, pred))
  78. for t in remove:
  79. del self['data'][t[0]]['predictionResults'][t[1]]
  80. # remove label from list
  81. del self['labels'][identifier]
  82. def predict(self, identifiers):
  83. # create pipeline
  84. pipeline = self.__create_pipeline()
  85. # run jobs
  86. for file_id in identifiers:
  87. if file_id in self['data'].keys():
  88. pipeline.run(self['data'][file_id])
  89. # schedule timeout thread
  90. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  91. def fit(self):
  92. # create pipeline
  93. pipeline = self.__create_pipeline()
  94. # run fit
  95. pipeline.fit()
  96. # schedule timeout thread
  97. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  98. def __create_pipeline(self):
  99. # abort pipeline termination
  100. self.__quit_pipeline_thread()
  101. # create pipeline if it does not exist already
  102. if self.pipeline_manager is None:
  103. self.pipeline_manager = PipelineManager(self)
  104. return self.pipeline_manager
  105. def __quit_pipeline(self):
  106. if self.pipeline_manager is not None:
  107. self.pipeline_manager.close()
  108. self.pipeline_manager = None
  109. self.quit_pipeline_thread = None
  110. def __create_quit_pipeline_thread(self):
  111. # abort pipeline termination
  112. self.__quit_pipeline_thread()
  113. # create new thread
  114. self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
  115. def __quit_pipeline_thread(self):
  116. if self.quit_pipeline_thread is not None:
  117. self.quit_pipeline_thread.cancel()
  118. self.quit_pipeline_thread = None