Browse Source

added parts handling in multiprocessing iterator creation (shared_mem size is now adapted to it)

Dimitri Korsch 4 năm trước cách đây
mục cha
commit
42b2a45d70
1 tập tin đã thay đổi với 3 bổ sung1 xóa
  1. 3 1
      cvdatasets/utils/dataset.py

+ 3 - 1
cvdatasets/utils/dataset.py

@@ -23,6 +23,8 @@ def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch
 			pass
 
 		input_shape = getattr(data, "size", (512, 512))
+		n_parts = getattr(data, "n_parts", 1)
+
 		if isinstance(input_shape, int):
 			input_shape = (input_shape, input_shape)
 		elif not isinstance(input_shape, tuple):
@@ -33,7 +35,7 @@ def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch
 				input_shape = (512, 512)
 
 		shared_mem_shape = (3,) + input_shape
-		shared_mem = np.zeros(shared_mem_shape, dtype=np.float32).nbytes
+		shared_mem = (n_parts+1) * np.zeros(shared_mem_shape, dtype=np.float32).nbytes
 		logging.info(f"Using {batch_size * shared_mem / 1024**2: .3f} MiB of shared memory")
 
 		it_kwargs = dict(