Преглед изворни кода

added possibility for optional annotation files. refactored file reading

Dimitri Korsch пре 5 година
родитељ
комит
3935683c36

+ 1 - 1
cvdatasets/_version.py

@@ -1 +1 @@
-__version__ = "0.6.3"
+__version__ = "0.7.0"

+ 14 - 6
cvdatasets/annotations/base/__init__.py

@@ -49,8 +49,8 @@ class BaseAnnotations(abc.ABC):
 				root_or_infofile
 			))
 
-		for fname, attr in self.meta.structure:
-			self.read_content(fname, attr)
+		for struc in self.meta.structure:
+			self.read_content(*struc)
 
 		self.load()
 
@@ -130,14 +130,16 @@ class BaseAnnotations(abc.ABC):
 	def _open(self, file):
 		return open(self._path(file))
 
-	def read_content(self, file, attr):
+	def set_content_from_file(self, file, attr, reader, optional=False):
 		content = None
 		fpath = self._path(file)
+
 		if isfile(fpath):
 			with self._open(file) as f:
-				content = [line.strip() for line in f if line.strip()]
-		else:
-			msg = "File \"{}\" was not found!".format(fpath)
+				content = reader(f)
+
+		elif not optional:
+			msg = f"File \"{fpath}\" was not found!"
 			if self.load_strict:
 				raise AssertionError(msg)
 			else:
@@ -145,6 +147,12 @@ class BaseAnnotations(abc.ABC):
 
 		setattr(self, attr, content)
 
+	def read_content(self, file, attr, optional=False):
+
+		def reader(f):
+			return [line.strip() for line in f if line.strip()]
+
+		self.set_content_from_file(file, attr, reader, optional)
 
 	def load(self):
 		logging.debug("Loading uuids, labels and training-test split")

+ 1 - 1
cvdatasets/annotations/impl/imagenet.py

@@ -29,7 +29,7 @@ class INET_Annotations(PartsMixin, BaseAnnotations):
 
 		return info
 
-	def read_content(self, folder_name, attr):
+	def read_content(self, folder_name, attr, optional=False):
 		folder_path = self._path(folder_name)
 		logging.info(f"Loading images from folder \"{folder_path}\" ...")
 

+ 39 - 7
cvdatasets/annotations/impl/inat.py

@@ -1,6 +1,9 @@
+import hashlib
+import logging
 import numpy as np
 import simplejson as json
 
+from os.path import isfile
 from os.path import join
 
 from cvdatasets.annotations.base import BaseAnnotations
@@ -9,15 +12,16 @@ from cvdatasets.annotations.base.parts_mixin import PartsMixin
 from cvdatasets.utils import _MetaInfo
 
 
+def _uuid_entry(im_info):
+	return hashlib.md5(im_info["file_name"].encode()).hexdigest()
+
 class BaseINAT_Annotations(BBoxMixin, PartsMixin, BaseAnnotations):
 
-	def read_content(self, json_file, attr):
+	def read_content(self, json_file, attr, optional=False):
 		if not json_file.endswith(".json"):
-			return super(BaseINAT_Annotations, self).read_content(json_file, attr)
-		with self._open(json_file) as f:
-			content = json.load(f)
-			setattr(self, attr, content)
+			return super(BaseINAT_Annotations, self).read_content(json_file, attr, optional=optional)
 
+		self.set_content_from_file(json_file, attr, json.load, optional)
 
 	def _load_bounding_boxes(self):
 		self.bounding_boxes = np.zeros(len(self.uuids), dtype=self.meta.bounding_box_dtype)
@@ -63,6 +67,7 @@ class INAT20_Annotations(BaseINAT_Annotations):
 			images_folder="images",
 			content="trainval.json",
 			val_content="val.json",
+			unlabeled_content="unlabeled_train.json",
 
 			# fake bounding boxes: the whole image
 			bounding_box_dtype=np.dtype([(v, np.int32) for v in "xywh"]),
@@ -74,11 +79,38 @@ class INAT20_Annotations(BaseINAT_Annotations):
 		info.structure = [
 			[info.content, "_content"],
 			[info.val_content, "_val_content"],
-			[info.parts_file, "_part_locs"],
-			[info.part_names_file, "_part_names"],
+			[info.unlabeled_content, "_unlabeled_content", True],
+			[info.parts_file, "_part_locs", True],
+			[info.part_names_file, "_part_names", True],
 		]
 		return info
 
+	@property
+	def has_unlabeled_data(self):
+		return self._unlabeled_content is not None
+
+	def _load_uuids(self, *args, **kwargs):
+		super(INAT20_Annotations, self)._load_uuids(*args, **kwargs)
+
+		if not self.has_unlabeled_data:
+			logging.info("No unlabled data was provided!")
+			return
+
+		logging.info("Loading unlabled data...")
+		uuid_fnames = [(_uuid_entry(im), im["file_name"]) for im in self._unlabeled_content["images"]]
+		self.unlabled_uuids, self.unlabeled_images = map(np.array, zip(*uuid_fnames))
+
+		assert len(np.unique(self.unlabled_uuids)) == len(self.unlabled_uuids), \
+			"Unlabled UUIDs are not unique!"
+
+		overlap = set(self.uuids) & set(self.unlabled_uuids)
+		assert len(overlap) == 0, \
+			f"Unlabled and labeled UUIDs overlap: {overlap}"
+
+		self.unlabled_uuid_to_idx = {uuid: i for i, uuid in enumerate(self.unlabled_uuids)}
+
+
+
 class INAT19_Annotations(BaseINAT_Annotations):
 
 	name="INAT19"

+ 1 - 1
scripts/display.py

@@ -19,7 +19,7 @@ def main(args):
 	annotation_cls = AnnotationType[args.dataset].value
 
 	logging.info(f"Loading \"{args.dataset}\" annnotations from \"{args.data}\"")
-	annot = annotation_cls.new(args)
+	annot = annotation_cls(root_or_infofile=args.data, parts=args.parts, load_strict=False)
 
 	kwargs = {}
 	if annot.info is None:

+ 3 - 2
scripts/utils/parser.py

@@ -103,6 +103,7 @@ def parse_args():
 			action="store_true"),
 	], group_name="Display options")
 
-	parser.add_args([Arg('--seed', type=int, default=12311123, help='random seed')])
-	parser.init_logger()
+	parser.add_args([
+		Arg('--seed', type=int, default=12311123, help='random seed')
+	])
 	return parser.parse_args()