Эх сурвалжийг харах

propagating the information about center crop to the rescale function

Dimitri Korsch 6 жил өмнө
parent
commit
e63bc75319

+ 2 - 2
cvdatasets/dataset/image.py

@@ -22,14 +22,14 @@ class ImageWrapper(object):
 		return im
 
 
-	def __init__(self, im_path, label, parts=None, mode="RGB", part_rescale_size=None):
+	def __init__(self, im_path, label, parts=None, mode="RGB", part_rescale_size=None, center_cropped=True):
 
 		self.mode = mode
 		self.im = im_path
 		self._im_array = None
 
 		self.label = label
-		self.parts = Parts(self.im, parts, part_rescale_size)
+		self.parts = Parts(self.im, parts, part_rescale_size, center_cropped)
 
 		self.parent = None
 		self._feature = None

+ 6 - 2
cvdatasets/dataset/mixins/reading.py

@@ -7,12 +7,13 @@ from ..image import ImageWrapper
 
 class AnnotationsReadMixin(BaseMixin):
 
-	def __init__(self, uuids, annotations, part_rescale_size=None, mode="RGB"):
+	def __init__(self, uuids, annotations, part_rescale_size=None, center_cropped=True, mode="RGB"):
 		super(AnnotationsReadMixin, self).__init__()
 		self.uuids = uuids
 		self._annot = annotations
 		self.mode = mode
 		self.part_rescale_size = part_rescale_size
+		self.center_cropped = center_cropped
 
 	def __len__(self):
 		return len(self.uuids)
@@ -28,7 +29,10 @@ class AnnotationsReadMixin(BaseMixin):
 		methods = ["image", "parts", "label"]
 		im_path, parts, label = [self._get(m, i) for m in methods]
 
-		return ImageWrapper(im_path, int(label), parts, mode=self.mode, part_rescale_size=self.part_rescale_size)
+		return ImageWrapper(im_path, int(label), parts,
+			mode=self.mode,
+			part_rescale_size=self.part_rescale_size,
+			center_cropped=self.center_cropped)
 
 	@property
 	def n_parts(self):

+ 1 - 1
cvdatasets/dataset/part/base.py

@@ -83,7 +83,7 @@ class BasePart(ABC):
 		if len(annotation) == 4:
 			return LocationPart(image, annotation, rescale_size)
 		elif len(annotation) == 5:
-			return BBoxPart(image, annotation, rescale_size)
+			return BBoxPart(image, annotation, rescale_size, center_cropped)
 		else:
 			raise ValueError("Unknown part annotation format: {}".format(annotation))
 

+ 3 - 3
cvdatasets/dataset/part/collection.py

@@ -7,12 +7,12 @@ from .annotation import BBoxPart
 
 class Parts(BasePartCollection):
 
-	def __init__(self, image, part_annotations, rescale_size):
+	def __init__(self, image, part_annotations, *args, **kwargs):
 		super(Parts, self).__init__()
 		if part_annotations is None:
 			self._parts = []
 		else:
-			self._parts = [BasePart.new(image, a, rescale_size) for a in part_annotations]
+			self._parts = [BasePart.new(image, a, *args, **kwargs) for a in part_annotations]
 
 
 class UniformParts(BasePartCollection):
@@ -37,4 +37,4 @@ class UniformParts(BasePartCollection):
 			row, col = np.unravel_index(i, (n, m))
 			x, y = col * part_w, row * part_h
 
-			yield BBoxPart(im, [i, x, y, part_w, part_h])
+			yield BBoxPart(im, [i, x, y, part_w, part_h], center_cropped=False)