6
0

Pipeline.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
from typing import List, Optional

from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
from pycs.interfaces.MediaFile import MediaFile
from pycs.interfaces.MediaStorage import MediaStorage

  5. class Pipeline:
  6. """
  7. pipeline interface that should be implemented by model developers
  8. """
  9. #pylint: disable=unnecessary-pass
  10. def __init__(self, root_folder: str, distribution: dict):
  11. """
  12. prepare everything needed to run jobs later
  13. :param root_folder: relative path to model folder
  14. :param distribution: dict parsed from distribution.json
  15. """
  16. pass
  17. #pylint: disable=unnecessary-pass
  18. def close(self):
  19. """
  20. is called everytime a pipeline is not needed anymore and should be used
  21. to free native resources
  22. :return:
  23. """
  24. pass
  25. #pylint: disable=no-self-use
  26. def collections(self) -> List[dict]:
  27. """
  28. is called while initializing a pipeline to receive available
  29. collections
  30. :return: list of collections or None
  31. """
  32. return []
  33. @staticmethod
  34. def create_collection(reference: str,
  35. name: str,
  36. description: str = None,
  37. autoselect: bool = False) -> dict:
  38. """
  39. create a collection dict
  40. :param reference: unique reference
  41. :param name: collection name
  42. :param description: collection description
  43. :param autoselect: show this collection by default if it contains elements
  44. :return: collection dict
  45. """
  46. return {
  47. 'reference': reference,
  48. 'name': name,
  49. 'description': description,
  50. 'autoselect': autoselect
  51. }
  52. def execute(self, storage: MediaStorage, file: MediaFile):
  53. """
  54. receive a file, create predictions and add them to the object
  55. :param storage: database abstraction object
  56. :param file: which should be analyzed
  57. """
  58. raise NotImplementedError
  59. def pure_inference(self, storage: MediaStorage, file: MediaFile, bounding_boxes: List[MediaBoundingBox]):
  60. """
  61. receive a file and a list of bounding boxes and only create a
  62. classification for the given bounding boxes.
  63. :param storage: database abstraction object
  64. :param file: which should be analyzed
  65. :param bounding_boxes: only perform inference for the given bounding boxes
  66. :return: labels for the given bounding boxes
  67. """
  68. raise NotImplementedError
  69. def fit(self, storage: MediaStorage):
  70. """
  71. receive a list of annotated media files and adapt the underlying model
  72. :param storage: database abstraction object
  73. """
  74. raise NotImplementedError