Dimitri Korsch 4 жил өмнө
parent
commit
d6cadd2ebd

+ 14 - 3
cvdatasets/annotation/types/file_list.py

@@ -5,6 +5,16 @@ from cvdatasets.annotation.files import AnnotationFiles
 
 class FileListAnnotations(Annotations):
 
+	@classmethod
+	def extract_kwargs(cls, opts, ds_info=None, *args, **kwargs):
+		kwargs = super(FileListAnnotations, cls).extract_kwargs(opts, ds_info=ds_info, *args, **kwargs)
+		kwargs["test_fold_id"] = getattr(opts, "test_fold_id", 0)
+		return kwargs
+
+	def __init__(self, *args, test_fold_id=0, **kwargs):
+		self._test_fold_id = test_fold_id
+		super(FileListAnnotations, self).__init__(*args, **kwargs)
+
 	def load_files(self, file_obj) -> AnnotationFiles:
 		file_obj.load_files("images.txt", "labels.txt", "tr_ID.txt")
 		return file_obj
@@ -30,10 +40,11 @@ class FileListAnnotations(Annotations):
 
 		assert hasattr(self, "uuids"), \
 			"UUIDs were not parsed yet! Please call _parse_uuids before this method!"
-
 		uuid_to_split = {uuid: int(split) for uuid, split in zip(self.uuids, self.files.tr_ID)}
-		self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
-		self.test_split = np.logical_not(self.train_split)
+
+		split_ids = np.array([uuid_to_split[uuid] for uuid in self.uuids])
+		self.test_split = split_ids == self._test_fold_id
+		self.train_split = np.logical_not(self.test_split)
 
 if __name__ == '__main__':
 	annot = FileListAnnotations(