|
@@ -9,13 +9,23 @@ def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch
|
|
|
|
|
|
if n_jobs > 0:
|
|
|
it_cls = MultiprocessIterator
|
|
|
+ try:
|
|
|
+ import cv2
|
|
|
+ cv2.setNumThreads(0)
|
|
|
+ except ImportError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ input_shape = getattr(data, "_size", (512, 512))
|
|
|
+ shared_mem_shape = (batch_size, 3) + tuple(input_shape)
|
|
|
+ shared_mem = np.zeros(shared_mem_shape, dtype=np.float32).nbytes
|
|
|
+ logging.info(f"Using {shared_mem / 1024**2: .3f} MiB of shared memory")
|
|
|
+
|
|
|
it_kwargs = dict(
|
|
|
n_processes=n_jobs,
|
|
|
n_prefetch=n_prefetch,
|
|
|
batch_size=batch_size,
|
|
|
repeat=repeat, shuffle=shuffle,
|
|
|
- shared_mem=np.zeros((32,3,1024,1024), dtype=np.float32).nbytes)
|
|
|
-
|
|
|
+ shared_mem=shared_mem)
|
|
|
else:
|
|
|
it_cls = SerialIterator
|
|
|
it_kwargs = dict(
|