浏览代码

fixed shared mem pre-computation for multi process iterator

Dimitri Korsch 4 年之前
父节点
当前提交
9acaa92324
共有 2 个文件被更改,包括 3 次插入3 次删除
  1. 1 1
      cvdatasets/_version.py
  2. 2 2
      cvdatasets/utils/dataset.py

+ 1 - 1
cvdatasets/_version.py

@@ -1 +1 @@
-__version__ = "0.8.2"
+__version__ = "0.8.3"

+ 2 - 2
cvdatasets/utils/dataset.py

@@ -32,9 +32,9 @@ def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch
 				warnings.warn(f"Could not parse input_shape: \"{input_shape}\". Falling back to a default value of (512, 512)")
 				input_shape = (512, 512)
 
-		shared_mem_shape = (batch_size, 3) + input_shape
+		shared_mem_shape = (3,) + input_shape
 		shared_mem = np.zeros(shared_mem_shape, dtype=np.float32).nbytes
-		logging.info(f"Using {shared_mem / 1024**2: .3f} MiB of shared memory")
+		logging.info(f"Using {batch_size * shared_mem / 1024**2: .3f} MiB of shared memory")
 
 		it_kwargs = dict(
 			n_processes=n_jobs,