瀏覽代碼

added new anntotation type for iNat dataset

Dimitri Korsch 5 年之前
父節點
當前提交
c7d4b59ace
共有 3 個文件被更改,包括 21 次插入2 次删除
  1. 1 0
      cvdatasets/annotations/__init__.py
  2. 12 2
      cvdatasets/utils/dataset.py
  3. 8 0
      scripts/info_files/info.yml

+ 1 - 0
cvdatasets/annotations/__init__.py

@@ -43,5 +43,6 @@ class AnnotationType(BaseChoiceType):
 	INAT20 = INAT20_Annotations
 	INAT20_TEST = partial(INAT20_Annotations)
 	INAT20_IN_CLASS = partial(INAT20_Annotations)
+	INAT20_OUT_CLASS = partial(INAT20_Annotations)
 
 	Default = CUB200

+ 12 - 2
cvdatasets/utils/dataset.py

@@ -1,5 +1,6 @@
 import logging
 import numpy as np
+import warnings
 
 def _format_kwargs(kwargs):
 	return " ".join([f"{key}={value}" for key, value in kwargs.items()])
@@ -15,8 +16,17 @@ def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch
 		except ImportError:
 			pass
 
-		input_shape = getattr(data, "_size", (512, 512))
-		shared_mem_shape = (batch_size, 3) + tuple(input_shape)
+		input_shape = getattr(data, "size", (512, 512))
+		if isinstance(input_shape, int):
+			input_shape = (input_shape, input_shape)
+		elif not isinstance(input_shape, tuple):
+			try:
+				input_shape = tuple(input_shape)
+			except TypeError as e:
+				warnings.warn(f"Could not parse input_shape: \"{input_shape}\". Falling back to a default value of (512, 512)")
+				input_shape = (512, 512)
+
+		shared_mem_shape = (batch_size, 3) + input_shape
 		shared_mem = np.zeros(shared_mem_shape, dtype=np.float32).nbytes
 		logging.info(f"Using {shared_mem / 1024**2: .3f} MiB of shared memory")
 

+ 8 - 0
scripts/info_files/info.yml

@@ -109,6 +109,10 @@ DATASETS:
     <<: *inat20
     annotations: "2020/IN_CLASS"
 
+  INAT20_OUT_CLASS:         &inat20_out_class
+    <<: *inat20
+    annotations: "2020/OUT_CLASS"
+
   INAT20_TEST:    &inat20_test
     <<: *inat20
     annotations: "2020/TEST"
@@ -279,6 +283,10 @@ PARTS:
     <<: *inat20_in_class
     <<: *parts_global
 
+  INAT20_OUT_CLASS_GLOBAL:
+    <<: *inat20_out_class
+    <<: *parts_global
+
   HED_GLOBAL:
     <<: *hed
     <<: *parts_global