|
@@ -1,5 +1,6 @@
|
|
|
import logging
|
|
|
import numpy as np
|
|
|
+import warnings
|
|
|
|
|
|
def _format_kwargs(kwargs):
|
|
|
return " ".join([f"{key}={value}" for key, value in kwargs.items()])
|
|
@@ -15,8 +16,17 @@ def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch
|
|
|
except ImportError:
|
|
|
pass
|
|
|
|
|
|
- input_shape = getattr(data, "_size", (512, 512))
|
|
|
- shared_mem_shape = (batch_size, 3) + tuple(input_shape)
|
|
|
+ input_shape = getattr(data, "size", (512, 512))
|
|
|
+ if isinstance(input_shape, int):
|
|
|
+ input_shape = (input_shape, input_shape)
|
|
|
+ elif not isinstance(input_shape, tuple):
|
|
|
+ try:
|
|
|
+ input_shape = tuple(input_shape)
|
|
|
+        except TypeError:
|
|
|
+ warnings.warn(f"Could not parse input_shape: \"{input_shape}\". Falling back to a default value of (512, 512)")
|
|
|
+ input_shape = (512, 512)
|
|
|
+
|
|
|
+ shared_mem_shape = (batch_size, 3) + input_shape
|
|
|
shared_mem = np.zeros(shared_mem_shape, dtype=np.float32).nbytes
|
|
|
logging.info(f"Using {shared_mem / 1024**2: .3f} MiB of shared memory")
|
|
|
|