from abc import ABC, abstractmethod, abstractproperty
from contextlib import contextmanager

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle

from . import utils


class BaseParts(ABC):
    """Shared behaviour for an ordered collection of image parts.

    Subclasses populate ``self._parts`` with ``BasePart`` instances; the
    methods here only rely on that list and on the parts' attributes and
    methods (``_id``, ``x``, ``y``, ``xy``, ``is_visible``, ``crop``,
    ``plot``, ``reveal``).
    """

    def __getitem__(self, i):
        return self._parts[i]

    @property
    def selected(self):
        """Boolean mask with one entry per part, True where the part is visible."""
        return np.array([p.is_visible for p in self._parts], dtype=bool)

    @property
    def selected_idxs(self):
        """Positions (in ``self._parts``) of the visible parts."""
        return np.where(self.selected)[0]

    def select(self, idxs):
        """Make visible exactly those parts whose ``_id`` is contained in *idxs*.

        Args:
            idxs: a collection of part ids, or a boolean numpy mask. A mask is
                first converted to the positions of its True entries.

        NOTE(review): after the mask conversion, positions are compared with
        part ids, so masks only behave as expected when every part's ``_id``
        equals its position in the collection -- confirm for annotated parts.
        """
        if isinstance(idxs, np.ndarray) and idxs.dtype == bool:
            # a mask is present, so convert it to indices
            idxs = np.where(idxs)[0]

        for p in self._parts:
            p.is_visible = p._id in idxs

    def invert_selection(self):
        """Hide the visible parts and show the hidden ones (see the note on ``select``)."""
        self.select(np.logical_not(self.selected))

    def offset(self, dx, dy):
        """Shift every part's ``x``/``y`` by (dx, dy), in place."""
        for p in self._parts:
            p.x += dx
            p.y += dy

    def visible_locs(self):
        """Return the ids and coordinates of the visible parts.

        Returns:
            A pair ``(idxs, xy)``: ``idxs`` has shape (k,) and ``xy`` has
            shape (2, k) with the x values in the first row and the y values
            in the second, where k is the number of visible parts. Both are
            empty when no part is visible.
        """
        vis = [(p._id, p.xy) for p in self._parts if p.is_visible]
        if not vis:
            # zip(*[]) yields nothing, so the tuple unpacking below would fail
            return np.array([], dtype=int), np.zeros((2, 0))

        idxs, xy = zip(*vis)
        return np.array(idxs), np.array(xy).T

    def visible_crops(self, *args, **kwargs):
        """Crop every part; hidden parts yield zero-filled crops.

        All arguments are forwarded to each part's ``crop`` method. The result
        has one entry per part, visible or not.
        """
        return np.array([p.crop(*args, **kwargs) for p in self._parts])

    def plot(self, cmap=plt.cm.jet, **kwargs):
        """Plot every part, each with its own colour sampled from *cmap*.

        Remaining keyword arguments are forwarded to each part's ``plot``.
        """
        for i, p in enumerate(self._parts):
            p.plot(color=cmap(i / len(self._parts)), **kwargs)

    def reveal(self, im, ratio, *args, **kwargs):
        """Return an array like *im* that is zero except at the visible parts.

        Args:
            im: image array, indexed as (rows, cols, channels).
            ratio: forwarded to each part's ``reveal`` (``LocationPart`` uses
                it as the part size relative to the image).
        """
        res = np.zeros_like(im)
        for part in self._parts:
            if not part.is_visible:
                continue
            # pass ratio positionally so that any extra *args follow it
            # instead of colliding with a ratio= keyword
            x, y, crop = part.reveal(im, ratio, *args, **kwargs)
            h, w, _ = crop.shape
            res[y:y + h, x:x + w] = crop
        return res


class Parts(BaseParts):
    """Parts built from per-part annotation rows (see ``BasePart.new``)."""

    def __init__(self, image, part_annotations, rescale_size):
        super(Parts, self).__init__()
        self._parts = [BasePart.new(image, a, rescale_size) for a in part_annotations]


class UniformParts(BaseParts):
    """Parts forming a regular grid of boxes over the whole image."""

    def __init__(self, image, ratio):
        super(UniformParts, self).__init__()
        self._parts = list(self.generate_parts(image, ratio))

    def generate_parts(self, im, ratio, round_op=np.floor):
        """Yield ``BBoxPart`` tiles that cover *im* in a row-major grid.

        Each tile is initially ``round_op(size * ratio)`` pixels along each
        axis; the tile size is then stretched so the grid spans the image as
        evenly as integer sizes allow.

        Raises:
            ValueError: if *ratio* yields a tile smaller than one pixel or
                larger than the image along either axis.
        """
        h, w, c = utils.dimensions(im)
        part_w = int(round_op(w * ratio))
        part_h = int(round_op(h * ratio))
        if not (0 < part_w <= w and 0 < part_h <= h):
            raise ValueError(
                "Invalid ratio {} for an image of size {}x{} (w x h)".format(ratio, w, h))

        # number of tiles along the width (columns) and the height (rows)
        n_cols, n_rows = w // part_w, h // part_h

        # fit best possible part_w and part_h
        part_w = int(w / n_cols)
        part_h = int(h / n_rows)

        for i in range(n_cols * n_rows):
            # row-major order: rows run along the height, columns along the width
            row, col = np.unravel_index(i, (n_rows, n_cols))
            x, y = col * part_w, row * part_h
            yield BBoxPart(im, [i, x, y, part_w, part_h])


class BasePart(ABC):
    """A single image part, positioned by its ``x``/``y`` attributes."""

    @staticmethod
    def new(image, annotation, rescale_size=-1):
        """Create the part type that matches the annotation's length.

        4 values -> ``LocationPart`` (id, x, y, visible);
        5 values -> ``BBoxPart`` (id, x, y, w, h).

        Raises:
            ValueError: for any other annotation length.
        """
        if len(annotation) == 4:
            return LocationPart(image, annotation, rescale_size)
        elif len(annotation) == 5:
            return BBoxPart(image, annotation, rescale_size)
        else:
            raise ValueError("Unknown part annotation format: {}".format(annotation))

    def rescale(self, image, annotation, rescale_size):
        """Map the annotation's x, y (entries 1 and 2) onto *image*'s size.

        The coordinates are multiplied by ``(image_w, image_h) / rescale_size``.
        When ``rescale_size`` is None or not positive, the annotation is
        returned unchanged. Otherwise a rescaled copy is returned and the
        caller's annotation is left untouched.
        """
        if rescale_size is not None and rescale_size > 0:
            h, w, c = utils.dimensions(image)
            scale = np.array([w, h]) / rescale_size
            # copy first: if annotation is a view into a larger array (e.g. a
            # row of the annotations given to Parts), scaling in place would
            # modify that array and compound on repeated construction
            annotation = np.array(annotation, copy=True)
            annotation[1:3] = annotation[1:3] * scale
        return annotation

    @property
    def is_visible(self):
        return self._is_visible

    @is_visible.setter
    def is_visible(self, value):
        self._is_visible = bool(value)

    @property
    def xy(self):
        """The part's (x, y) as a numpy array."""
        return np.array([self.x, self.y])

    def crop(self, im, w, h, padding_mode="edge", is_location=True):
        """Return an (h, w, channels) crop of *im* for this part.

        Args:
            im: image, indexed as (rows, cols, channels).
            w, h: crop width and height in pixels.
            padding_mode: ``np.pad`` mode used where the crop leaves the image.
            is_location: if True, (x, y) is the crop centre; otherwise it is
                the crop's top-left corner.

        Hidden parts return a zero-filled uint8 array of the same size.
        """
        w, h = int(w), int(h)
        if not self.is_visible:
            _, _, c = utils.dimensions(im)
            return np.zeros((h, w, c), dtype=np.uint8)

        x, y = (int(v) for v in self.xy)
        # pad by half the crop size so a crop near the border stays in bounds
        pad_h, pad_w = h // 2, w // 2
        padded_im = np.pad(im, [(pad_h, pad_h), (pad_w, pad_w), (0, 0)], mode=padding_mode)

        # (x, y) expressed in padded-image coordinates
        x0, y0 = x + pad_w, y + pad_h
        if is_location:
            # move from the centre to the top-left corner of the crop
            x0, y0 = x0 - w // 2, y0 - h // 2

        return padded_im[y0:y0 + h, x0:x0 + w]

    def plot(self, **kwargs):
        """No-op for the base class; subclasses draw themselves."""
        return


class LocationPart(BasePart):
    """A part given by its centre point; its size is a fraction of the image size."""

    DEFAULT_RATIO = np.sqrt(49 / 400)  # 0.35

    def __init__(self, image, annotation, rescale_size=None):
        super(LocationPart, self).__init__()

        annotation = self.rescale(image, annotation, rescale_size)
        # here x,y are the center of the part
        self._id, self.x, self.y, self.is_visible = annotation
        self._ratio = LocationPart.DEFAULT_RATIO

    def crop(self, image, ratio=None, padding_mode="edge", *args, **kwargs):
        """Crop a window of ``ratio`` times the image size centred on the part.

        A falsy *ratio* falls back to ``self._ratio``.
        """
        ratio = ratio or self._ratio
        _h, _w, c = utils.dimensions(image)
        w, h = int(_w * ratio), int(_h * ratio)
        return super(LocationPart, self).crop(image, w, h,
            padding_mode, is_location=True)

    def reveal(self, im, ratio=None, *args, **kwargs):
        """Return ``(x0, y0, patch)``: the part's window clipped to the image.

        The window is ``ratio`` times the image size, centred on the part
        (the same rectangle ``plot`` draws); ``(x0, y0)`` is the top-left
        corner of the clipped patch. A falsy *ratio* falls back to
        ``self._ratio``.
        """
        ratio = ratio or self._ratio
        _h, _w, c = utils.dimensions(im)
        w, h = int(_w * ratio), int(_h * ratio)
        x, y = (int(v) for v in self.xy)

        x0, y0 = x - w // 2, y - h // 2
        # clip both ends; a negative stop would otherwise wrap around
        x1, y1 = max(x0 + w, 0), max(y0 + h, 0)
        x0, y0 = max(x0, 0), max(y0, 0)
        return x0, y0, im[y0:y1, x0:x1]

    def plot(self, im, ax, ratio=None, fill=False, linestyle="--", **kwargs):
        """Draw the part's window on *ax* as a rectangle (hidden parts are skipped).

        A falsy *ratio* falls back to ``self._ratio``; extra keyword arguments
        go to ``matplotlib.patches.Rectangle``.
        """
        if not self.is_visible:
            return
        ratio = ratio or self._ratio
        x, y = self.xy
        _h, _w, c = utils.dimensions(im)
        w, h = int(_w * ratio), int(_h * ratio)
        ax.add_patch(Rectangle(
            (x - w // 2, y - h // 2), w, h,
            fill=fill, linestyle=linestyle,
            **kwargs
        ))


class BBoxPart(BasePart):
    """A part given by an explicit box: top-left corner (x, y) plus w and h."""

    def __init__(self, image, annotation, rescale_size=None):
        super(BBoxPart, self).__init__()

        annotation = self.rescale(image, annotation, rescale_size)
        # here x,y are top left corner of the part
        self._id, self.x, self.y, self.w, self.h = annotation
        self.is_visible = True

    def rescale(self, image, annotation, rescale_size):
        """Like ``BasePart.rescale`` but also scales w, h (entries 3 and 4)."""
        if rescale_size is not None and rescale_size > 0:
            # the base implementation returns a copy here, so the in-place
            # update below never touches the caller's annotation
            annotation = super(BBoxPart, self).rescale(image, annotation, rescale_size)
            h, w, c = utils.dimensions(image)
            scale = np.array([w, h]) / rescale_size
            annotation[3:5] = annotation[3:5] * scale
        return annotation

    def crop(self, image, padding_mode="edge", *args, **kwargs):
        """Crop exactly the box (padded with *padding_mode* beyond the image)."""
        return super(BBoxPart, self).crop(image, self.w, self.h,
            padding_mode, is_location=False)

    def reveal(self, im, ratio, *args, **kwargs):
        """Return ``(x, y, patch)`` for the box; *ratio* is unused here."""
        x, y = (int(v) for v in self.xy)
        w, h = int(self.w), int(self.h)
        return x, y, im[y:y + h, x:x + w]

    def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
        """Draw the box on *ax* (hidden parts are skipped); *im*/*ratio* unused."""
        if not self.is_visible:
            return
        ax.add_patch(Rectangle(
            (self.x, self.y), self.w, self.h,
            fill=fill, linestyle=linestyle,
            **kwargs
        ))