
refactored logging a bit and fixed shared memory usage

Dimitri Korsch, 5 years ago
commit 055752a058
1 changed file with 13 additions and 4 deletions

cvdatasets/utils/dataset.py (+13 −4)

@@ -1,20 +1,29 @@
 import logging
 import numpy as np
 
+def _format_kwargs(kwargs):
+	return " ".join([f"{key}={value}" for key, value in kwargs.items()])
+
 def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
 	from chainer.iterators import SerialIterator, MultiprocessIterator
 
 	if n_jobs > 0:
-		it = MultiprocessIterator(data,
+		it_cls = MultiprocessIterator
+		it_kwargs = dict(
 			n_processes=n_jobs,
 			n_prefetch=n_prefetch,
 			batch_size=batch_size,
 			repeat=repeat, shuffle=shuffle,
-			shared_mem=np.zeros((3,1024,1024), dtype=np.float32).nbytes)
+			shared_mem=np.zeros((32,3,1024,1024), dtype=np.float32).nbytes)
+
 	else:
-		it = SerialIterator(data,
+		it_cls = SerialIterator
+		it_kwargs = dict(
 			batch_size=batch_size,
 			repeat=repeat, shuffle=shuffle)
-	logging.info("Using {it.__class__.__name__} with batch size {it.batch_size}".format(it=it))
+
+	it = it_cls(data, **it_kwargs)
 	n_batches = int(np.ceil(len(data) / it.batch_size))
+	logging.info(f"Using {it_cls.__name__} with {n_batches:,d} batches per epoch and kwargs: {_format_kwargs(it_kwargs)}")
+
 	return it, n_batches
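
For context, a minimal usage sketch (not part of the commit; the toy dataset is hypothetical), assuming chainer is installed: new_iterator picks MultiprocessIterator when n_jobs > 0 and SerialIterator otherwise, and returns the iterator together with the number of batches per epoch.

import numpy as np

from cvdatasets.utils.dataset import new_iterator

# hypothetical toy dataset: 100 float32 "images" of shape (3, 64, 64)
data = [np.zeros((3, 64, 64), dtype=np.float32) for _ in range(100)]

# n_jobs=0 selects SerialIterator; n_jobs > 0 would select
# MultiprocessIterator with the shared_mem sizing changed above
it, n_batches = new_iterator(data, n_jobs=0, batch_size=32)
assert n_batches == 4  # ceil(100 / 32)

batch = next(it)  # a list of batch_size examples (repeat=True wraps around the epoch)

On the shared-memory change itself: np.zeros((3, 1024, 1024), dtype=np.float32).nbytes is 12,582,912 bytes (12 MiB), while the new np.zeros((32, 3, 1024, 1024), dtype=np.float32).nbytes is 32 times larger, 402,653,184 bytes (384 MiB), i.e. room for a batch of 32 such images instead of a single one.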