MediaFile.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. from typing import List
  2. from typing import Optional
  3. from typing import Union
  4. from pycs.database.File import File
  5. from pycs.database.Result import Result
  6. from pycs.frontend.notifications.NotificationList import NotificationList
  7. from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
  8. from pycs.interfaces.MediaImageLabel import MediaImageLabel
  9. from pycs.interfaces.MediaLabel import MediaLabel
  10. class MediaFile:
  11. """
  12. contains various attributes of a saved media file
  13. """
  14. def __init__(self, file: File, notifications: NotificationList):
  15. self.__file = file
  16. self.__notifications = notifications
  17. self.type = file.type
  18. self.size = file.size
  19. self.frames = file.frames
  20. self.fps = file.fps
  21. self.path = file.absolute_path
  22. def set_collection(self, reference: Optional[str]):
  23. """
  24. set this file's collection
  25. :param reference: use None to remove this file's collection
  26. """
  27. self.__file.set_collection_by_reference(reference)
  28. self.__notifications.add("edit", "file", self.__file.id, cls=File)
  29. def set_image_label(self, label: Union[int, MediaLabel], frame: int = None):
  30. """
  31. create a labeled-image result
  32. :param label: label id
  33. :param frame: frame index (only set for videos)
  34. """
  35. if label is not None and isinstance(label, MediaLabel):
  36. label = label.id
  37. if frame is not None:
  38. data = {'frame': frame}
  39. else:
  40. data = None
  41. created = self.__file.create_result('pipeline', 'labeled-image', label, data)
  42. self.__notifications.add("create", "result", created.id, cls=Result)
  43. def add_bounding_box(self, x: float, y: float, w: float, h: float,
  44. label: Union[int, MediaLabel] = None, frame: int = None):
  45. """
  46. create a bounding-box result
  47. :param x: relative x coordinate [0, 1]
  48. :param y: relative y coordinate [0, 1]
  49. :param w: relative width [0, 1]
  50. :param h: relative height [0, 1]
  51. :param label: label
  52. :param frame: frame index (only set for videos)
  53. """
  54. result = {
  55. 'x': x,
  56. 'y': y,
  57. 'w': w,
  58. 'h': h
  59. }
  60. if frame is not None:
  61. result['frame'] = frame
  62. if label is not None and isinstance(label, MediaLabel):
  63. label = label.id
  64. created = self.__file.create_result('pipeline', 'bounding-box', label, result)
  65. self.__notifications.add("create", "result", created.id, cls=Result)
  66. def remove_predictions(self):
  67. """
  68. remove and return all predictions added from pipelines
  69. """
  70. removed = self.__file.remove_results(origin='pipeline')
  71. for r in removed:
  72. self.__notifications.add("remove", "result", r.serialize())
  73. def __get_results(self, origin: str) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
  74. def map_r(result: Result) -> Union[MediaImageLabel, MediaBoundingBox]:
  75. if result.type == 'labeled-image':
  76. return MediaImageLabel(result)
  77. return MediaBoundingBox(result)
  78. return list(map(map_r, self.__file.results.filter_by(origin=origin)))
  79. def results(self) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
  80. """
  81. receive results added by users
  82. :return: list of results
  83. """
  84. return self.__get_results('user')
  85. def predictions(self) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
  86. """
  87. receive results added by pipelines
  88. :return: list of predictions
  89. """
  90. return self.__get_results('pipeline')
  91. def serialize(self) -> dict:
  92. """
  93. serialize all object properties to a dict
  94. :return: dict
  95. """
  96. return {
  97. 'type': self.type,
  98. 'size': self.size,
  99. 'frames': self.frames,
  100. 'fps': self.fps,
  101. 'path': self.path,
  102. 'filename': self.__file.name + self.__file.extension,
  103. 'results': list(map(lambda r: r.serialize(), self.results())),
  104. 'predictions': list(map(lambda r: r.serialize(), self.predictions())),
  105. }