|
@@ -2,11 +2,12 @@ import abc
|
|
|
import typing as T
|
|
|
|
|
|
from chainer_addons.links import PoolingType
|
|
|
-from chainer_addons.models import PrepareType
|
|
|
from cvargparse import Arg
|
|
|
from cvargparse import BaseParser
|
|
|
+
|
|
|
from cvfinetune.parser.utils import parser_extender
|
|
|
from cvmodelz.models import ModelFactory
|
|
|
+from cvmodelz.models import PrepareType
|
|
|
|
|
|
@parser_extender
|
|
|
def add_model_args(parser: BaseParser, *,
|
|
@@ -27,11 +28,13 @@ def add_model_args(parser: BaseParser, *,
|
|
|
choices=["imagenet", "inat"],
|
|
|
help="type of model pre-training"),
|
|
|
|
|
|
- Arg("--input_size", type=int, nargs="+", default=0,
|
|
|
- help="overrides default input size of the model, if greater than 0"),
|
|
|
+ Arg("--input_size", default="infer",
|
|
|
+ help="overrides default input size of the model, if greater than 0. " \
|
|
|
+ "Set to \"infer\" to guess the input size from the model type."),
|
|
|
|
|
|
- Arg("--part_input_size", type=int, nargs="+", default=0,
|
|
|
- help="overrides default input part size of the model, if greater than 0"),
|
|
|
+ Arg("--part_input_size", default="infer",
|
|
|
+ help="overrides default input part size of the model, if greater than 0. " \
|
|
|
+ "Set to \"infer\" to guess the input size from the model type."),
|
|
|
|
|
|
PrepareType.as_arg("prepare_type",
|
|
|
help_text="type of image preprocessing"),
|
|
@@ -39,10 +42,10 @@ def add_model_args(parser: BaseParser, *,
|
|
|
PoolingType.as_arg("pooling",
|
|
|
help_text="type of pre-classification pooling"),
|
|
|
|
|
|
- Arg("--load", type=str,
|
|
|
+ Arg("--load",
|
|
|
help="ignore weights and load already fine-tuned model (classifier will NOT be re-initialized and number of classes will be unchanged)"),
|
|
|
|
|
|
- Arg("--weights", type=str,
|
|
|
+ Arg("--weights",
|
|
|
help="ignore default weights and load already pre-trained model (classifier will be re-initialized and number of classes will be changed)"),
|
|
|
|
|
|
Arg("--headless", action="store_true",
|
|
@@ -51,7 +54,7 @@ def add_model_args(parser: BaseParser, *,
|
|
|
Arg("--load_strict", action="store_true",
|
|
|
help="load weights in a strict mode"),
|
|
|
|
|
|
- Arg("--load_path", type=str, default="",
|
|
|
+ Arg("--load_path", default="",
|
|
|
help="load path within the weights archive"),
|
|
|
]
|
|
|
|