__init__.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. from typing import List
  2. import cv2
  3. import numpy as np
  4. from json import dump, load
  5. from pycs.interfaces.MediaFile import MediaFile
  6. from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
  7. from pycs.interfaces.MediaStorage import MediaStorage
  8. from pycs.interfaces.Pipeline import Pipeline as Interface
  9. from .detector import Detector
  10. from .classifier import Classifier
  11. from .detector import BBox
  12. class Scanner(Interface):
  13. def __init__(self, root_folder: str, configuration: dict):
  14. super().__init__(root_folder, configuration)
  15. self.detector = Detector(configuration["detector"])
  16. self.classifier = Classifier(configuration["classifier"], root=root_folder)
  17. def close(self):
  18. pass
  19. def execute(self, storage: MediaStorage, file: MediaFile):
  20. im = self.read_image(file.path)
  21. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  22. bw_im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
  23. detections = self.detector(bw_im)
  24. labels = {ml.reference: ml for ml in storage.labels()}
  25. for bbox, info in detections:
  26. if not info.selected:
  27. continue
  28. x0, y0, x1, y1 = bbox
  29. cls_ref = self.classifier(bbox.crop(im, enlarge=True))
  30. label = labels.get(cls_ref, cls_ref)
  31. file.add_bounding_box(x0, y0, bbox.w, bbox.h, label=label)
  32. def pure_inference(self, storage: MediaStorage, file: MediaFile, bounding_boxes: List[MediaBoundingBox]):
  33. im = self.read_image(file.path)
  34. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  35. bw_im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
  36. labels = {ml.reference: ml for ml in storage.labels()}
  37. bbox_labels = []
  38. for bbox in bounding_boxes:
  39. bbox = BBox(bbox.x, bbox.y, bbox.x + bbox.w , bbox.y + bbox.h)
  40. x0, y0, x1, y1 = bbox
  41. cls_ref = self.classifier(bbox.crop(im, enlarge=True))
  42. label = labels.get(cls_ref, cls_ref)
  43. bbox_labels.append(label)
  44. return bbox_labels
  45. def read_image(self, path: str, mode: int = cv2.IMREAD_COLOR) -> np.ndarray:
  46. return cv2.imread(path, mode)