6
0

CreateProject.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from contextlib import closing
  2. from os import mkdir
  3. from os import path
  4. from pathlib import Path
  5. from shutil import copytree
  6. from uuid import uuid1
  7. from flask import abort
  8. from flask import make_response
  9. from flask import request
  10. from flask.views import View
  11. from pycs import app
  12. from pycs import db
  13. from pycs.database.LabelProvider import LabelProvider
  14. from pycs.database.Model import Model
  15. from pycs.database.Project import Project
  16. from pycs.frontend.endpoints.projects.ExecuteExternalStorage import ExecuteExternalStorage
  17. from pycs.frontend.endpoints.projects.ExecuteLabelProvider import ExecuteLabelProvider
  18. from pycs.frontend.notifications.NotificationManager import NotificationManager
  19. from pycs.jobs.JobRunner import JobRunner
  20. from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline
  21. class CreateProject(View):
  22. """
  23. create a project, insert data into database and create directories
  24. """
  25. # pylint: disable=arguments-differ
  26. methods = ['POST']
  27. def __init__(self, nm: NotificationManager, jobs: JobRunner):
  28. # pylint: disable=invalid-name
  29. self.nm = nm
  30. self.jobs = jobs
  31. @property
  32. def project_folder(self):
  33. return app.config["TEST_PROJECTS_DIR"] if app.config["TESTING"] else 'projects'
  34. def dispatch_request(self):
  35. # extract request data
  36. data = request.get_json(force=True)
  37. if None in [data.get('name'), data.get('description')]:
  38. return abort(400, "name and description information missing!")
  39. name = data['name']
  40. description = data['description']
  41. model_id = data['model']
  42. label_provider_id = data['label']
  43. data_folder = data['external']
  44. external_data = data_folder is not None
  45. # find model
  46. model = Model.query.get(model_id)
  47. if model is None:
  48. return abort(404, "Model not found")
  49. # find label provider
  50. if label_provider_id is None:
  51. label_provider = None
  52. else:
  53. label_provider = LabelProvider.query.get(label_provider_id)
  54. if label_provider is None:
  55. return abort(404, "Label provider not found")
  56. # create project folder
  57. project_folder = Path(self.project_folder, str(uuid1()))
  58. project_folder.mkdir(parents=True)
  59. temp_folder = project_folder / 'temp'
  60. temp_folder.mkdir()
  61. # check project data directory
  62. if external_data:
  63. # check if exists
  64. if not path.exists(data_folder):
  65. return abort(400, "Data folder does not exist!")
  66. else:
  67. data_folder = project_folder / 'data'
  68. data_folder.mkdir()
  69. # copy model to project folder
  70. model_folder = project_folder / 'model'
  71. copytree(model.root_folder, str(model_folder))
  72. model, _ = model.copy_to(f'{model.name} ({name})', str(model_folder))
  73. # create entry in database
  74. project = Project.new(
  75. name=name,
  76. description=description,
  77. model_id=model.id,
  78. label_provider_id=label_provider.id,
  79. root_folder=str(project_folder),
  80. external_data=external_data,
  81. data_folder=str(data_folder)
  82. )
  83. # execute label provider and add labels to project
  84. if label_provider is not None:
  85. ExecuteLabelProvider.execute_label_provider(self.nm, self.jobs, project,
  86. label_provider)
  87. root_folder = model.root_folder
  88. # load model and add collections to the project
  89. def load_model_and_get_collections():
  90. with closing(load_pipeline(root_folder)) as pipeline:
  91. return pipeline.collections()
  92. project_id = project.id
  93. def add_collections_to_project(provided_collections):
  94. project = Project.query.get(project_id)
  95. with db.session.begin_nested():
  96. for position, collection in enumerate(provided_collections):
  97. project.create_collection(
  98. collection['reference'],
  99. collection['name'],
  100. collection['description'],
  101. position + 1,
  102. collection['autoselect'],
  103. commit=False,
  104. )
  105. self.jobs.run(project,
  106. 'Media Collections',
  107. f'{project.name}',
  108. f'{project.id}/media-collections',
  109. executable=load_model_and_get_collections,
  110. result=add_collections_to_project)
  111. # find media files
  112. if external_data:
  113. ExecuteExternalStorage.find_media_files(self.nm, self.jobs, project)
  114. # fire event
  115. self.nm.create_model(model.id)
  116. self.nm.create_project(project.id)
  117. # return success response
  118. return make_response()