MediaFile.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from os import path, getcwd
  2. from typing import Optional, List, Union
  3. from pycs.database.File import File
  4. from pycs.database.Result import Result
  5. from pycs.frontend.notifications.NotificationList import NotificationList
  6. from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
  7. from pycs.interfaces.MediaImageLabel import MediaImageLabel
  8. from pycs.interfaces.MediaLabel import MediaLabel
  9. class MediaFile:
  10. """
  11. contains various attributes of a saved media file
  12. """
  13. def __init__(self, file: File, notifications: NotificationList):
  14. self.__file = file
  15. self.__notifications = notifications
  16. self.type = file.type
  17. self.size = file.size
  18. self.frames = file.frames
  19. self.fps = file.fps
  20. if path.isabs(file.path):
  21. self.path = file.path
  22. else:
  23. self.path = path.join(getcwd(), file.path)
  24. def set_collection(self, reference: Optional[str]):
  25. """
  26. set this file's collection
  27. :param reference: use None to remove this file's collection
  28. """
  29. self.__file.set_collection_by_reference(reference)
  30. self.__notifications.add(self.__notifications.nm.edit_file, self.__file)
  31. def set_image_label(self, label: Union[int, MediaLabel]):
  32. """
  33. create a labeled-image result
  34. :param label: label identifier
  35. """
  36. if label is not None and isinstance(label, MediaLabel):
  37. label = label.identifier
  38. created = self.__file.create_result('pipeline', 'labeled-image', label)
  39. self.__notifications.add(self.__notifications.nm.create_result, created)
  40. def add_bounding_box(self, x: float, y: float, w: float, h: float,
  41. label: Union[int, MediaLabel] = None, frame: int = None):
  42. """
  43. create a bounding-box result
  44. :param x: relative x coordinate [0, 1]
  45. :param y: relative y coordinate [0, 1]
  46. :param w: relative width [0, 1]
  47. :param h: relative height [0, 1]
  48. :param label: label
  49. :param frame: frame index
  50. """
  51. result = {
  52. 'x': x,
  53. 'y': y,
  54. 'w': w,
  55. 'h': h
  56. }
  57. if frame is not None:
  58. result['frame'] = frame
  59. if label is not None and isinstance(label, MediaLabel):
  60. label = label.identifier
  61. created = self.__file.create_result('pipeline', 'bounding-box', label, result)
  62. self.__notifications.add(self.__notifications.nm.create_result, created)
  63. def remove_predictions(self):
  64. """
  65. remove and return all predictions added from pipelines
  66. """
  67. removed = self.__file.remove_results(origin='pipeline')
  68. for r in removed:
  69. self.__notifications.add(self.__notifications.nm.remove_result, r)
  70. def __get_results(self, origin: str) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
  71. def map_r(result: Result) -> Union[MediaImageLabel, MediaBoundingBox]:
  72. if result.type == 'labeled-image':
  73. return MediaImageLabel(result)
  74. else:
  75. return MediaBoundingBox(result)
  76. return list(map(map_r,
  77. filter(lambda r: r.origin == origin,
  78. self.__file.results())))
  79. def results(self) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
  80. """
  81. receive results added by users
  82. :return: list of results
  83. """
  84. return self.__get_results('user')
  85. def predictions(self) -> List[Union[MediaImageLabel, MediaBoundingBox]]:
  86. """
  87. receive results added by pipelines
  88. :return: list of predictions
  89. """
  90. return self.__get_results('pipeline')
  91. def serialize(self) -> dict:
  92. return {
  93. 'type': self.type,
  94. 'size': self.size,
  95. 'frames': self.frames,
  96. 'fps': self.fps,
  97. 'path': self.path,
  98. 'results': list(map(lambda r: r.serialize(), self.results())),
  99. 'predictions': list(map(lambda r: r.serialize(), self.predictions())),
  100. }