Browse Source

some updates in MultiprocessIterator creation

Dimitri Korsch 5 years ago
parent
commit
67e49ec726
2 changed files with 13 additions and 2 deletions
  1. 1 0
      cvdatasets/annotations/__init__.py
  2. 12 2
      cvdatasets/utils/dataset.py

+ 1 - 0
cvdatasets/annotations/__init__.py

@@ -42,5 +42,6 @@ class AnnotationType(BaseChoiceType):
 
 	INAT20 = INAT20_Annotations
 	INAT20_TEST = partial(INAT20_Annotations)
+	INAT20_IN_CLASS = partial(INAT20_Annotations)
 
 	Default = CUB200

+ 12 - 2
cvdatasets/utils/dataset.py

@@ -9,13 +9,23 @@ def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch
 
 	if n_jobs > 0:
 		it_cls = MultiprocessIterator
+		try:
+			import cv2
+			cv2.setNumThreads(0)
+		except ImportError:
+			pass
+
+		input_shape = getattr(data, "_size", (512, 512))
+		shared_mem_shape = (batch_size, 3) + tuple(input_shape)
+		shared_mem = np.zeros(shared_mem_shape, dtype=np.float32).nbytes
+		logging.info(f"Using {shared_mem / 1024**2: .3f} MiB of shared memory")
+
 		it_kwargs = dict(
 			n_processes=n_jobs,
 			n_prefetch=n_prefetch,
 			batch_size=batch_size,
 			repeat=repeat, shuffle=shuffle,
-			shared_mem=np.zeros((32,3,1024,1024), dtype=np.float32).nbytes)
-
+			shared_mem=shared_mem)
 	else:
 		it_cls = SerialIterator
 		it_kwargs = dict(