6
0

__init__.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from typing import List
  2. import cv2
  3. import numpy as np
  4. from json import dump, load
  5. from pycs.interfaces.MediaFile import MediaFile
  6. from pycs.interfaces.MediaBoundingBox import MediaBoundingBox
  7. from pycs.interfaces.MediaStorage import MediaStorage
  8. from pycs.interfaces.Pipeline import Pipeline as Interface
  9. from blob_detector.core.bbox import BBox
  10. from .detector import Detector
  11. from .classifier import Classifier
  12. class Scanner(Interface):
  13. def __init__(self, root_folder: str, configuration: dict):
  14. super().__init__(root_folder, configuration)
  15. self.detector = Detector(configuration["detector"])
  16. self.classifier = None #Classifier(configuration["classifier"], root=root_folder)
  17. def close(self):
  18. pass
  19. def predict(self, im, bboxes, storage):
  20. for bbox in bboxes:
  21. if not bbox.is_valid:
  22. continue
  23. if self.classifier is None:
  24. yield bbox, None
  25. continue
  26. x0, y0, x1, y1 = bbox
  27. cls_ref = self.classifier(bbox.crop(im, enlarge=True))
  28. label = labels.get(cls_ref, cls_ref)
  29. yield bbox, label
  30. def execute(self, storage: MediaStorage, file: MediaFile):
  31. im = self.read_image(file.path)
  32. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  33. bw_im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
  34. bboxes = self.detector(bw_im)
  35. labels = {ml.reference: ml for ml in storage.labels()}
  36. for bbox, label in self.predict(im, bboxes, storage):
  37. file.add_bounding_box(
  38. bbox.x0, bbox.y0,
  39. bbox.w, bbox.h, label=label)
  40. def pure_inference(self, storage: MediaStorage, file: MediaFile, bounding_boxes: List[MediaBoundingBox]):
  41. im = self.read_image(file.path)
  42. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  43. bboxes = [BBox(bbox.x, bbox.y, bbox.x + bbox.w , bbox.y + bbox.h)
  44. for bbox in bounding_boxes]
  45. bbox_labels = []
  46. for bbox, label in self.predict(im, bboxes, storage):
  47. bbox_labels.append(label)
  48. return bbox_labels
  49. def read_image(self, path: str, mode: int = cv2.IMREAD_COLOR) -> np.ndarray:
  50. return cv2.imread(path, mode)