Selaa lähdekoodia

fixed the setting of part surrogates

Dimitri Korsch 5 vuotta sitten
vanhempi
commit
0c6188e695

+ 14 - 4
cvdatasets/dataset/image.py

@@ -1,12 +1,14 @@
-from imageio import imread
 from PIL import Image
+from imageio import imread
 from os.path import isfile
 
 import copy
 import numpy as np
 
+from .part import Parts
+from .part import UniformParts
+from .part import SurrogateType
 from cvdatasets import utils
-from .part import Parts, UniformParts
 
 def should_have_parts(func):
 	def inner(self, *args, **kwargs):
@@ -32,7 +34,12 @@ class ImageWrapper(object):
 			raise RuntimeError("Reading image \"{}\" failed after {} n_retries! ({})".format(im_path, n_retries, error))
 
 
-	def __init__(self, im_path, label, parts=None, mode="RGB", part_rescale_size=None, center_cropped=True):
+	def __init__(self, im_path, label,
+		parts=None,
+		mode="RGB",
+		part_rescale_size=None,
+		part_surrogate_type=SurrogateType.IMAGE,
+		center_cropped=True):
 
 		self.mode = mode
 		self._im = None
@@ -41,7 +48,10 @@ class ImageWrapper(object):
 		self.im = im_path
 
 		self.label = label
-		self.parts = Parts(self.im, parts, part_rescale_size, center_cropped)
+		self.parts = Parts(self.im, parts,
+			rescale_size=part_rescale_size,
+			surrogate_type=part_surrogate_type,
+			center_cropped=center_cropped)
 
 		self.parent = None
 		self._feature = None

+ 3 - 1
cvdatasets/dataset/part/__init__.py

@@ -1 +1,3 @@
-from .collection import Parts, UniformParts
+from .collection import Parts
+from .collection import UniformParts
+from .surrogate import SurrogateType

+ 1 - 1
cvdatasets/dataset/part/base.py

@@ -114,7 +114,7 @@ class BasePart(ABC):
 		if not self.is_visible:
 			surrogate = self._surrogate_type
 			if surrogate is not None and callable(surrogate):
-				return surrogate(im, w, h, im.dtype)
+				return surrogate(im, w, h, dtype=np.uint8)
 			else:
 				warnings.warn("Part surrogate was not set, but is needed! Returning blank patch as fallback.")
 				_, _, c = utils.dimensions(im)