CreateProject.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import shutil
  2. import uuid
  3. from contextlib import closing
  4. from pathlib import Path
  5. from flask import abort
  6. from flask import make_response
  7. from flask import request
  8. from flask.views import View
  9. from pycs import db
  10. from pycs import settings
  11. from pycs.database.LabelProvider import LabelProvider
  12. from pycs.database.Model import Model
  13. from pycs.database.Project import Project
  14. from pycs.frontend.endpoints.projects.ExecuteExternalStorage import ExecuteExternalStorage
  15. from pycs.frontend.endpoints.projects.ExecuteLabelProvider import ExecuteLabelProvider
  16. from pycs.frontend.notifications.NotificationManager import NotificationManager
  17. from pycs.jobs.JobRunner import JobRunner
  18. from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline
  19. class CreateProject(View):
  20. """
  21. create a project, insert data into database and create directories
  22. """
  23. # pylint: disable=arguments-differ
  24. methods = ['POST']
  25. def __init__(self, nm: NotificationManager, jobs: JobRunner):
  26. # pylint: disable=invalid-name
  27. self.nm = nm
  28. self.jobs = jobs
  29. def dispatch_request(self):
  30. # extract request data
  31. data = request.get_json(force=True)
  32. name = data.get('name')
  33. description = data.get('description')
  34. if name is None:
  35. abort(400, "name argument is missing!")
  36. if description is None:
  37. abort(400, "description argument is missing!")
  38. model_id = int(data['model'])
  39. model = Model.get_or_404(model_id)
  40. label_provider_id = data.get('label')
  41. label_provider = None
  42. if label_provider_id is not None:
  43. label_provider = LabelProvider.get_or_404(label_provider_id)
  44. # create project folder
  45. project_folder = Path(settings.projects_folder, str(uuid.uuid1()))
  46. project_folder.mkdir(parents=True)
  47. temp_folder = project_folder / 'temp'
  48. temp_folder.mkdir()
  49. # check project data directory
  50. if data['external'] is None:
  51. external_data = False
  52. data_folder = project_folder / 'data'
  53. data_folder.mkdir()
  54. else:
  55. external_data = True
  56. data_folder = Path(data['external'])
  57. # check if exists
  58. if not data_folder.exists():
  59. return abort(400, f"External folder does not exist: {data_folder}")
  60. # copy model to project folder
  61. model_folder = project_folder / 'model'
  62. shutil.copytree(model.root_folder, str(model_folder))
  63. with db.session.begin_nested():
  64. model, is_new = model.copy_to(
  65. name=f'{model.name} ({name})',
  66. root_folder=str(model_folder),
  67. commit=False)
  68. model.flush()
  69. if not is_new:
  70. # pragma: no cover
  71. abort(400,
  72. f"Could not copy model! Model in \"{model_folder}\" already exists!")
  73. project = Project.new(name=name,
  74. description=description,
  75. model_id=model.id,
  76. label_provider_id=label_provider_id,
  77. root_folder=str(project_folder),
  78. external_data=external_data,
  79. data_folder=str(data_folder))
  80. # execute label provider and add labels to project
  81. if label_provider is not None:
  82. ExecuteLabelProvider.execute_label_provider(self.nm, self.jobs, project,
  83. label_provider)
  84. # load model and add collections to the project
  85. def load_model_and_get_collections():
  86. with closing(load_pipeline(model.root_folder)) as pipeline:
  87. return pipeline.collections()
  88. def add_collections_to_project(provided_collections):
  89. with db.session.begin_nested():
  90. for position, collection in enumerate(provided_collections, 1):
  91. project.create_collection(commit=False,
  92. position=position,
  93. **collection)
  94. self.jobs.run(project,
  95. 'Media Collections',
  96. f'{project.name}',
  97. f'{project.id}/media-collections',
  98. executable=load_model_and_get_collections,
  99. result=add_collections_to_project)
  100. # find media files
  101. if external_data:
  102. ExecuteExternalStorage.find_media_files(self.nm, self.jobs, project)
  103. # fire event
  104. self.nm.create_model(model)
  105. self.nm.create_project(project)
  106. # return success response
  107. return make_response()