瀏覽代碼

Added feature reading mixin

Dimitri Korsch 6 年之前
父節點
當前提交
48cc555994

+ 1 - 1
nabirds/__init__.py

@@ -1,4 +1,4 @@
 from .dataset import Dataset
 from .annotations import NAB_Annotations, CUB_Annotations
 
-__version__ = "0.1.7"
+__version__ = "0.1.8"

+ 2 - 1
nabirds/dataset/__init__.py

@@ -1,7 +1,8 @@
 from .mixins.reading import AnnotationsReadMixin, ImageListReadingMixin
 from .mixins.parts import PartMixin, RevealedPartMixin, CroppedPartMixin
+from .mixins.features import PreExtractedFeaturesMixin
 
-class Dataset(PartMixin, AnnotationsReadMixin):
+class Dataset(PartMixin, PreExtractedFeaturesMixin, AnnotationsReadMixin):
 
 	def get_example(self, i):
 		im_obj = super(Dataset, self).get_example(i)

+ 8 - 0
nabirds/dataset/image.py

@@ -24,6 +24,7 @@ class ImageWrapper(object):
 		self.parts = parts
 
 		self.parent = None
+		self._feature = None
 
 	def as_tuple(self):
 		return self.im, self.parts, self.label
@@ -33,6 +34,13 @@ class ImageWrapper(object):
 		new.parent = self
 		return new
 
+	@property
+	def feature(self):
+		return self._feature
+
+	@feature.setter
+	def feature(self, im_feature):
+		self._feature = im_feature
 
 	def crop(self, x, y, w, h):
 		result = self.copy()

+ 43 - 0
nabirds/dataset/mixins/features.py

@@ -0,0 +1,43 @@
+import numpy as np
+
+from os.path import isfile
+
+from . import BaseMixin
+
+
+class PreExtractedFeaturesMixin(BaseMixin):
+
+	def __size_check(self):
+		assert len(self.features) == len(self), \
+			"Number of features ({}) does not match the number of images ({})!".format(
+				len(self.features), len(self)
+			)
+
+	def __init__(self, features=None, *args, **kw):
+		super(PreExtractedFeaturesMixin, self).__init__(*args, **kw)
+
+		self.features = None
+		if features is not None and isfile(features):
+			self.features = self.load_features(features)
+			self.__size_check()
+
+	def load_features(self, features_file):
+		"""
+			Default feature loading from a file.
+			If you desire another feature loading logic,
+			subclass this mixin and override this method.
+		"""
+		try:
+			cont = np.load(features_file)
+			return cont["features"]
+		except Exception as e:
+			msg = "Error occured while reading features: \"{}\". ".format(e) + \
+				"If you want another feature loading logic, override this method!"
+			raise ValueError(msg)
+
+	def get_example(self, i):
+		im_obj = super(PreExtractedFeaturesMixin, self).get_example(i)
+		if self.features is not None:
+			im_obj.feature = self.features[i]
+
+		return im_obj

+ 1 - 0
nabirds/dataset/mixins/reading.py

@@ -29,6 +29,7 @@ class AnnotationsReadMixin(BaseMixin):
 
 
 class ImageListReadingMixin(BaseMixin):
+
 	def __init__(self, pairs, root="."):
 		super(ImageListReadingMixin, self).__init__()
 		with open(pairs) as f:

+ 12 - 2
nabirds/display.py

@@ -51,9 +51,14 @@ def main(args):
 	logging.info("Loading \"{}\" annnotations from \"{}\"".format(args.dataset, args.data))
 	annot = annotation_cls.get(args.dataset.lower())(args.data)
 
-	uuids = getattr(annot, "{}_uuids".format(args.subset.lower()))
+	subset = args.subset.lower()
+
+	uuids = getattr(annot, "{}_uuids".format(subset))
+	features = args.features[0 if subset == "train" else 1]
+
 	data = Dataset(
 		uuids=uuids, annotations=annot,
+		features=features,
 
 		uniform_parts=args.uniform_parts,
 
@@ -68,7 +73,7 @@ def main(args):
 
 	)
 	n_images = len(data)
-	logging.info("Found {} images in the {} subset".format(n_images, args.subset))
+	logging.info("Found {} images in the {} subset".format(n_images, subset))
 
 	for i in range(n_images):
 		if i + 1 <= args.start: continue
@@ -120,6 +125,11 @@ parser.add_argument("--dataset",
 	choices=["cub", "nab"],
 	default="nab", type=str)
 
+parser.add_argument("--features",
+	help="pre-extracted train and test features",
+	default=[None, None],
+	nargs=2, type=str)
+
 parser.add_argument("--subset",
 	help="Possible subsets: train, test",
 	choices=["train", "test"],