Procházet zdrojové kódy

added reading of unlabeled data

Dimitri Korsch před 5 roky
rodič
revize
656dc6ea40
1 změnil soubory, kde provedl 15 přidání a 8 odebrání
  1. 15 8
      cvdatasets/annotations/impl/inat.py

+ 15 - 8
cvdatasets/annotations/impl/inat.py

@@ -1,3 +1,4 @@
+import copy
 import hashlib
 import logging
 import numpy as np
@@ -93,21 +94,27 @@ class INAT20_Annotations(BaseINAT_Annotations):
 		super(INAT20_Annotations, self)._load_uuids(*args, **kwargs)
 
 		if not self.has_unlabeled_data:
-			logging.info("No unlabled data was provided!")
+			logging.info("No unlabeled data was provided!")
 			return
 
-		logging.info("Loading unlabled data...")
+		logging.info("Loading unlabeled data...")
 		uuid_fnames = [(_uuid_entry(im), im["file_name"]) for im in self._unlabeled_content["images"]]
-		self.unlabled_uuids, self.unlabeled_images = map(np.array, zip(*uuid_fnames))
 
-		assert len(np.unique(self.unlabled_uuids)) == len(self.unlabled_uuids), \
-			"Unlabled UUIDs are not unique!"
+		self.unlabeled = unlabeled = copy.copy(self)
 
-		overlap = set(self.uuids) & set(self.unlabled_uuids)
+		unlabeled.uuids, unlabeled.images = map(np.array, zip(*uuid_fnames))
+		unlabeled.labels = np.full(unlabeled.images.shape, -1, dtype=np.int32)
+		unlabeled.train_split = np.full(unlabeled.images.shape, 1, dtype=bool)
+		unlabeled.test_split = np.full(unlabeled.images.shape, 0, dtype=bool)
+
+		assert len(np.unique(unlabeled.uuids)) == len(unlabeled.uuids), \
+			"Unlabeled UUIDs are not unique!"
+
+		overlap = set(self.uuids) & set(unlabeled.uuids)
 		assert len(overlap) == 0, \
-			f"Unlabled and labeled UUIDs overlap: {overlap}"
+			f"Unlabeled and labeled UUIDs overlap: {overlap}"
 
-		self.unlabled_uuid_to_idx = {uuid: i for i, uuid in enumerate(self.unlabled_uuids)}
+		unlabeled.uuid_to_idx = {uuid: i for i, uuid in enumerate(unlabeled.uuids)}