Răsfoiți Sursa

fixed feature loading

Dimitri Korsch 6 ani în urmă
părinte
comite
ea575df2a3
1 a modificat fișierele cu 19 adăugiri și 12 ștergeri
  1. 19 12
      nabirds/annotations/base.py

+ 19 - 12
nabirds/annotations/base.py

@@ -1,8 +1,9 @@
-from os.path import join, isfile, isdir
 import numpy as np
-from collections import defaultdict, OrderedDict
 import abc
 import warnings
+import logging
+from os.path import join, isfile, isdir
+from collections import defaultdict, OrderedDict
 
 try:
 	from yaml import CLoader as Loader, CDumper as Dumper
@@ -10,13 +11,14 @@ except ImportError:
 	from yaml import Loader, Dumper
 
 import yaml
-import simplejson as json
 
 from nabirds.utils import attr_dict
 from nabirds.dataset import Dataset
 
 class BaseAnnotations(abc.ABC):
 
+	FEATURE_PHONY = dict(train=["train"], test=["test", "val"])
+
 	def __init__(self, root_or_infofile, parts=None, feature_model=None):
 		super(BaseAnnotations, self).__init__()
 		self.part_type = parts
@@ -59,7 +61,6 @@ class BaseAnnotations(abc.ABC):
 			self.info = attr_dict(yaml.load(f, Loader=Loader))
 
 		dataset_info = self.dataset_info
-		# print(json.dumps(dataset_info, indent=2))
 		annot_dir = join(self.data_root, dataset_info.folder, dataset_info.annotations)
 
 		assert isdir(annot_dir), "Annotation folder does exist! \"{}\"".format(annot_dir)
@@ -89,17 +90,23 @@ class BaseAnnotations(abc.ABC):
 			new_opts["part_rescale_size"] = dataset_info.rescale_size
 
 		if None not in [subset, self.feature_model]:
-			features = "{subset}_{suffix}.{model}.npz".format(
-				subset=subset,
-				suffix=dataset_info.feature_suffix,
-				model=self.feature_model)
-			feature_path = join(self.root, "features", features)
-			assert isfile(feature_path), \
-				"Features do not exist: \"{}\"".format(feature_path)
+			tried = []
+			for subset_phony in BaseAnnotations.FEATURE_PHONY[subset]:
+				features = "{subset}_{suffix}.{model}.npz".format(
+					subset=subset_phony,
+					suffix=dataset_info.feature_suffix,
+					model=self.feature_model)
+				feature_path = join(self.root, "features", features)
+				if isfile(feature_path): break
+				tried.append(feature_path)
+			else:
+				raise ValueError(
+					"Could not find any features in \"{}\" for {} subset. Tried features: {}".format(
+					join(self.root, "features"), subset, tried))
 			new_opts["features"] = feature_path
 		new_opts.update(kwargs)
 
-		print(new_opts)
+		logging.debug(new_opts)
 		return new_opts
 
 	@property