|
@@ -1,20 +1,29 @@
|
|
import logging
|
|
import logging
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
|
|
|
|
+def _format_kwargs(kwargs):
|
|
|
|
+ return " ".join([f"{key}={value}" for key, value in kwargs.items()])
|
|
|
|
+
|
|
def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
    """Create a chainer iterator over *data* and report batches per epoch.

    Args:
        data: dataset supporting ``len()`` and item access, as consumed by
            chainer iterators.
        n_jobs: number of worker processes; ``<= 0`` falls back to the
            single-process ``SerialIterator``.
        batch_size: number of examples per batch.
        repeat: passed through to the iterator (loop over epochs forever).
        shuffle: passed through to the iterator (reshuffle each epoch).
        n_prefetch: batches to prefetch (``MultiprocessIterator`` only).

    Returns:
        Tuple ``(it, n_batches)`` where ``n_batches`` is
        ``ceil(len(data) / batch_size)``.
    """
    # Local import keeps chainer out of module import time / optional.
    from chainer.iterators import SerialIterator, MultiprocessIterator

    if n_jobs > 0:
        it_cls = MultiprocessIterator
        it_kwargs = dict(
            n_processes=n_jobs,
            n_prefetch=n_prefetch,
            batch_size=batch_size,
            repeat=repeat, shuffle=shuffle,
            # Pre-allocated shared-memory size: 32 float32 images of
            # 3x1024x1024. NOTE(review): assumes no batch ever exceeds this
            # footprint — confirm against actual batch/image sizes.
            shared_mem=np.zeros((32, 3, 1024, 1024), dtype=np.float32).nbytes)
    else:
        it_cls = SerialIterator
        it_kwargs = dict(
            batch_size=batch_size,
            repeat=repeat, shuffle=shuffle)

    it = it_cls(data, **it_kwargs)
    n_batches = int(np.ceil(len(data) / it.batch_size))

    # Lazy %-style args: formatting is skipped entirely when INFO is disabled.
    logging.info(
        "Using %s with %s batches per epoch and kwargs: %s",
        it_cls.__name__, format(n_batches, ",d"), _format_kwargs(it_kwargs))

    return it, n_batches
|