|
@@ -5,6 +5,16 @@ from cvdatasets.annotation.files import AnnotationFiles
|
|
|
|
|
|
class FileListAnnotations(Annotations):
|
|
|
|
|
|
+ @classmethod
|
|
|
+ def extract_kwargs(cls, opts, ds_info=None, *args, **kwargs):
|
|
|
+ kwargs = super(FileListAnnotations, cls).extract_kwargs(opts, ds_info=ds_info, *args, **kwargs)
|
|
|
+ kwargs["test_fold_id"] = getattr(opts, "test_fold_id", 0)
|
|
|
+ return kwargs
|
|
|
+
|
|
|
+ def __init__(self, *args, test_fold_id=0, **kwargs):
|
|
|
+ self._test_fold_id = test_fold_id
|
|
|
+ super(FileListAnnotations, self).__init__(*args, **kwargs)
|
|
|
+
|
|
|
def load_files(self, file_obj) -> AnnotationFiles:
|
|
|
file_obj.load_files("images.txt", "labels.txt", "tr_ID.txt")
|
|
|
return file_obj
|
|
@@ -30,10 +40,11 @@ class FileListAnnotations(Annotations):
|
|
|
|
|
|
assert hasattr(self, "uuids"), \
|
|
|
"UUIDs were not parsed yet! Please call _parse_uuids before this method!"
|
|
|
-
|
|
|
uuid_to_split = {uuid: int(split) for uuid, split in zip(self.uuids, self.files.tr_ID)}
|
|
|
- self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
|
|
|
- self.test_split = np.logical_not(self.train_split)
|
|
|
+
|
|
|
+ split_ids = np.array([uuid_to_split[uuid] for uuid in self.uuids])
|
|
|
+ self.test_split = split_ids == self._test_fold_id
|
|
|
+ self.train_split = np.logical_not(self.test_split)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
annot = FileListAnnotations(
|