|
@@ -63,13 +63,13 @@ class ModelFactory(abc.ABC):
|
|
|
if model_cls in cls.supported["chainer"]:
|
|
|
if "pretrained_model" not in kwargs:
|
|
|
kwargs["pretrained_model"] = None
|
|
|
- kwargs.pop("input_size")
|
|
|
+ kwargs.pop("input_size", None)
|
|
|
|
|
|
elif model_cls in cls.supported["chainercv2"]:
|
|
|
if "pretrained" not in kwargs:
|
|
|
kwargs["pretrained"] = False
|
|
|
|
|
|
- input_size = kwargs.pop("input_size")
|
|
|
+ input_size = kwargs.pop("input_size", None)
|
|
|
return ModelWrapper(model_cls(*args, **kwargs), input_size=input_size)
|
|
|
|
|
|
return model_cls(*args, **kwargs)
|