6
0

Pipeline.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
from typing import List, Optional

from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
from pycs.interfaces.MediaStorage import MediaStorage
  5. class Pipeline:
  6. """
  7. pipeline interface that should be implemented by model developers
  8. """
  9. #pylint: disable=unnecessary-pass
  10. def __init__(self, root_folder: str, distribution: dict):
  11. """
  12. prepare everything needed to run jobs later
  13. :param root_folder: relative path to model folder
  14. :param distribution: dict parsed from distribution.json
  15. """
  16. pass
  17. #pylint: disable=unnecessary-pass
  18. def close(self):
  19. """
  20. is called everytime a pipeline is not needed anymore and should be used
  21. to free native resources
  22. :return:
  23. """
  24. pass
  25. #pylint: disable=no-self-use
  26. def collections(self) -> List[dict]:
  27. """
  28. is called while initializing a pipeline to receive available
  29. collections
  30. :return: list of collections or None
  31. """
  32. return []
  33. @staticmethod
  34. def create_collection(reference: str,
  35. name: str,
  36. description: str = None,
  37. autoselect: bool = False) -> dict:
  38. """
  39. create a collection dict
  40. :param reference: unique reference
  41. :param name: collection name
  42. :param description: collection description
  43. :param autoselect: show this collection by default if it contains elements
  44. :return: collection dict
  45. """
  46. return {
  47. 'reference': reference,
  48. 'name': name,
  49. 'description': description,
  50. 'autoselect': autoselect
  51. }
  52. def execute(self, storage: MediaStorage, file: MediaFile):
  53. """
  54. receive a file, create predictions and add them to the object
  55. :param storage: database abstraction object
  56. :param file: which should be analyzed
  57. """
  58. raise NotImplementedError
  59. def pure_inference(self, storage: MediaStorage, file: MediaFile,
  60. bounding_boxes: List[MediaBoundingBox]):
  61. """
  62. receive a file and a list of bounding boxes and only create a
  63. classification for the given bounding boxes.
  64. :param storage: database abstraction object
  65. :param file: which should be analyzed
  66. :param bounding_boxes: only perform inference for the given bounding boxes
  67. :return: labels for the given bounding boxes
  68. """
  69. raise NotImplementedError
  70. def fit(self, storage: MediaStorage):
  71. """
  72. receive a list of annotated media files and adapt the underlying model
  73. :param storage: database abstraction object
  74. """
  75. raise NotImplementedError