6
0

EstimateBoundingBox.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import typing as T
  2. import uuid
  3. import cv2
  4. import numpy as np
  5. from flask import abort
  6. from flask import make_response
  7. from flask import request
  8. from flask.views import View
  9. from pycs.database.File import File
  10. from pycs.database.Result import Result
  11. from pycs.frontend.notifications.NotificationManager import NotificationManager
  12. from pycs.jobs.JobGroupBusyException import JobGroupBusyException
  13. from pycs.jobs.JobRunner import JobRunner
  14. class EstimateBoundingBox(View):
  15. """
  16. create a result for a file
  17. """
  18. # pylint: disable=arguments-differ
  19. methods = ['POST']
  20. def __init__(self, nm: NotificationManager, jobs: JobRunner,):
  21. # pylint: disable=invalid-name
  22. self.nm = nm
  23. self.jobs = jobs
  24. def dispatch_request(self, file_id: int):
  25. file = File.get_or_404(file_id)
  26. request_data = request.get_json(force=True)
  27. if 'x' not in request_data or 'y' not in request_data:
  28. abort(400, "coordinates for the estimation are missing")
  29. x, y = map(request_data.get, "xy")
  30. # get project
  31. project = file.project
  32. try:
  33. rnd = str(uuid.uuid4())[:10]
  34. self.jobs.run(project,
  35. "Estimation",
  36. f'{project.name} (create predictions)',
  37. f"{project.id}/estimation/{rnd}",
  38. estimate,
  39. file.id, x, y,
  40. result=self.nm.create_result
  41. )
  42. except JobGroupBusyException:
  43. abort(400, "Job is already running!")
  44. return make_response()
  45. def estimate(file_id: int, x: float, y: float) -> Result:
  46. """ estimation function """
  47. file = File.query.get(file_id)
  48. image = cv2.imread(file.absolute_path, cv2.IMREAD_GRAYSCALE)
  49. h, w = image.shape
  50. pos = int(x * w), int(y * h)
  51. x0, y0, x1, y1 = detect(image, pos,
  52. window_size=1000,
  53. pixel_delta=50,
  54. enlarge=1e-2,
  55. )
  56. data = dict(
  57. x=x0 / w,
  58. y=y0 / h,
  59. w=(x1-x0) / w,
  60. h=(y1-y0) / h
  61. )
  62. return file.create_result('pipeline', 'bounding-box', label=None, data=data)
  63. def detect(image: np.ndarray,
  64. pos: T.Tuple[int, int],
  65. window_size: int = 1000,
  66. pixel_delta: int = 0,
  67. enlarge: float = -1) -> T.Tuple[int, int, int, int]:
  68. """ detection function """
  69. # image = blur(image, 3)
  70. x, y = pos
  71. pixel = image[y, x]
  72. min_pix, max_pix = pixel - pixel_delta, pixel + pixel_delta
  73. mask = np.logical_and(min_pix < image, image < max_pix).astype(np.float32)
  74. # mask = open_close(mask)
  75. # mask = blur(mask)
  76. pad = window_size // 2
  77. mask = np.pad(mask, pad, mode="constant")
  78. window = mask[y: y + window_size, x: x + window_size]
  79. sum_x, sum_y = window.sum(axis=0), window.sum(axis=1)
  80. enlarge = int(enlarge * max(image.shape))
  81. (x0, x1), (y0, y1) = get_borders(sum_x, enlarge), get_borders(sum_y, enlarge)
  82. x0 = max(x + x0 - pad, 0)
  83. y0 = max(y + y0 - pad, 0)
  84. x1 = min(x + x1 - pad, image.shape[1])
  85. y1 = min(y + y1 - pad, image.shape[0])
  86. return x0, y0, x1, y1
  87. def get_borders(arr, enlarge: int, eps=5e-1):
  88. """ returns borders based on coordinate extrema """
  89. mid = len(arr) // 2
  90. arr0, arr1 = arr[:mid], arr[mid:]
  91. thresh = arr[mid] * eps
  92. lowers = np.where(arr0 < thresh)[0]
  93. lower = 0 if len(lowers) == 0 else lowers[-1]
  94. uppers = np.where(arr1 < thresh)[0]
  95. upper = arr1.argmin() if len(uppers) == 0 else uppers[0]
  96. # since the second half starts after the first
  97. upper = len(arr0) + upper
  98. if enlarge > 0:
  99. lower = max(lower - enlarge, 0)
  100. upper = min(upper + enlarge, len(arr)-1)
  101. return int(lower), int(upper)