ExecuteLabelProvider.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from contextlib import closing
  2. from flask import abort
  3. from flask import make_response
  4. from flask import request
  5. from flask.views import View
  6. from pycs import db
  7. from pycs.database.LabelProvider import LabelProvider
  8. from pycs.database.Label import Label
  9. from pycs.database.Project import Project
  10. from pycs.frontend.notifications.NotificationManager import NotificationManager
  11. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  12. from pycs.jobs.JobRunner import JobRunner
  13. from tqdm import tqdm
  14. class ExecuteLabelProvider(View):
  15. """db
  16. execute the label provider associated with a passed project identifier
  17. """
  18. # pylint: disable=arguments-differ
  19. methods = ['POST']
  20. def __init__(self, nm: NotificationManager, jobs: JobRunner):
  21. # pylint: disable=invalid-name
  22. self.nm = nm
  23. self.jobs = jobs
  24. def dispatch_request(self, project_id: int):
  25. # extract request data
  26. data = request.get_json(force=True)
  27. if not data.get('execute', False):
  28. abort(400, "execute flag is missing")
  29. # find project
  30. project = Project.get_or_404(project_id)
  31. # get label provider
  32. label_provider = project.label_provider
  33. if label_provider is None:
  34. abort(400, "This project does not have a label provider.")
  35. # execute label provider and add labels to project
  36. try:
  37. self.execute_label_provider(self.nm, self.jobs, project, label_provider)
  38. except JobGroupBusyException:
  39. abort(400, "Label provider already running.")
  40. return make_response()
  41. @staticmethod
  42. def execute_label_provider(nm: NotificationManager, jobs: JobRunner,
  43. project: Project, label_provider: LabelProvider):
  44. """
  45. start a job that loads and executes a label provider and saves its results to the
  46. database afterwards
  47. :param nm: notification manager object
  48. :param jobs: job runner object
  49. :param project: project
  50. :param label_provider: label provider
  51. :return:
  52. """
  53. # pylint: disable=invalid-name
  54. # receive loads and executes the given label provider
  55. def receive():
  56. with closing(label_provider.load()) as label_provider_impl:
  57. return label_provider_impl.get_labels()
  58. project_id = project.id
  59. # result adds the received labels to the database and fires events
  60. def result(provided_labels):
  61. with db.session.begin():
  62. project = Project.query.get(project_id)
  63. labels = project.bulk_create_labels(provided_labels)
  64. # run job with given functions
  65. jobs.run(project,
  66. 'Label Provider',
  67. f'{project.name} ({label_provider.name})',
  68. f'{project.id}/label-provider',
  69. receive,
  70. result=result)