CreateProject.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. from contextlib import closing
  2. from os import mkdir
  3. from os import path
  4. from pathlib import Path
  5. from shutil import copytree
  6. from uuid import uuid1
  7. from flask import abort
  8. from flask import make_response
  9. from flask import request
  10. from flask.views import View
  11. from pycs import app
  12. from pycs import db
  13. from pycs.database.LabelProvider import LabelProvider
  14. from pycs.database.Model import Model
  15. from pycs.database.Project import Project
  16. from pycs.frontend.endpoints.projects.ExecuteExternalStorage import ExecuteExternalStorage
  17. from pycs.frontend.endpoints.projects.ExecuteLabelProvider import ExecuteLabelProvider
  18. from pycs.frontend.notifications.NotificationManager import NotificationManager
  19. from pycs.jobs.JobRunner import JobRunner
  20. from pycs.util.PipelineUtil import load_from_root_folder as load_pipeline
  21. class CreateProject(View):
  22. """
  23. create a project, insert data into database and create directories
  24. """
  25. # pylint: disable=arguments-differ
  26. methods = ['POST']
  27. @property
  28. def project_folder(self):
  29. return app.config["TEST_PROJECTS_DIR"] if app.config["TESTING"] else 'projects'
  30. def dispatch_request(self):
  31. # extract request data
  32. data = request.get_json(force=True)
  33. if None in [data.get('name'), data.get('description')]:
  34. return abort(400, "name and description information missing!")
  35. name = data['name']
  36. description = data['description']
  37. model_id = data['model']
  38. label_provider_id = data['label']
  39. data_folder = data['external']
  40. external_data = data_folder is not None
  41. # find model
  42. model = Model.query.get(model_id)
  43. if model is None:
  44. return abort(404, "Model not found")
  45. # find label provider
  46. if label_provider_id is None:
  47. label_provider = None
  48. else:
  49. label_provider = LabelProvider.query.get(label_provider_id)
  50. if label_provider is None:
  51. return abort(404, "Label provider not found")
  52. # create project folder
  53. project_folder = Path(self.project_folder, str(uuid1()))
  54. project_folder.mkdir(parents=True)
  55. temp_folder = project_folder / 'temp'
  56. temp_folder.mkdir()
  57. # check project data directory
  58. if external_data:
  59. # check if exists
  60. if not path.exists(data_folder):
  61. return abort(400, "Data folder does not exist!")
  62. else:
  63. data_folder = project_folder / 'data'
  64. data_folder.mkdir()
  65. # copy model to project folder
  66. model_folder = project_folder / 'model'
  67. copytree(model.root_folder, str(model_folder))
  68. model, _ = model.copy_to(f'{model.name} ({name})', str(model_folder))
  69. # create entry in database
  70. project = Project.new(
  71. name=name,
  72. description=description,
  73. model_id=model.id,
  74. label_provider_id=label_provider.id,
  75. root_folder=str(project_folder),
  76. external_data=external_data,
  77. data_folder=str(data_folder)
  78. )
  79. # execute label provider and add labels to project
  80. if label_provider is not None:
  81. ExecuteLabelProvider.execute_label_provider(project, label_provider)
  82. root_folder = model.root_folder
  83. # load model and add collections to the project
  84. def load_model_and_get_collections():
  85. with closing(load_pipeline(root_folder)) as pipeline:
  86. return pipeline.collections()
  87. project_id = project.id
  88. def add_collections_to_project(provided_collections):
  89. project = Project.query.get(project_id)
  90. with db.session.begin_nested():
  91. for position, collection in enumerate(provided_collections):
  92. project.create_collection(
  93. collection['reference'],
  94. collection['name'],
  95. collection['description'],
  96. position + 1,
  97. collection['autoselect'],
  98. commit=False,
  99. )
  100. JobRunner.run(project,
  101. 'Media Collections',
  102. f'{project.name}',
  103. f'{project.id}/media-collections',
  104. executable=load_model_and_get_collections,
  105. result=add_collections_to_project)
  106. # find media files
  107. if external_data:
  108. ExecuteExternalStorage.find_media_files(project)
  109. # fire event
  110. NotificationManager.created("model", model.id, Model)
  111. NotificationManager.created("project", project.id, Project)
  112. # return success response
  113. return make_response()