|
@@ -23,6 +23,8 @@ def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch
|
|
|
pass
|
|
|
|
|
|
input_shape = getattr(data, "size", (512, 512))
|
|
|
+ n_parts = getattr(data, "n_parts", 1)
|
|
|
+
|
|
|
if isinstance(input_shape, int):
|
|
|
input_shape = (input_shape, input_shape)
|
|
|
elif not isinstance(input_shape, tuple):
|
|
@@ -33,7 +35,7 @@ def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch
|
|
|
input_shape = (512, 512)
|
|
|
|
|
|
shared_mem_shape = (3,) + input_shape
|
|
|
- shared_mem = np.zeros(shared_mem_shape, dtype=np.float32).nbytes
|
|
|
+ shared_mem = (n_parts+1) * np.zeros(shared_mem_shape, dtype=np.float32).nbytes
|
|
|
logging.info(f"Using {batch_size * shared_mem / 1024**2: .3f} MiB of shared memory")
|
|
|
|
|
|
it_kwargs = dict(
|