|
@@ -40,7 +40,6 @@ def main(args):
|
|
|
|
|
|
parser.init_logger()
|
|
|
main(parser.parse_args())
|
|
|
-
|
|
|
```
|
|
|
|
|
|
This script can be called as following:
|
|
@@ -52,7 +51,93 @@ python script.py path/to/data path/to/labels path/to/model \
|
|
|
--batch_size 32 \
|
|
|
--epochs 90 \
|
|
|
--loglevel DEBUG \
|
|
|
- --logfile path/to/logs
|
|
|
+ --logfile path/to/logs
|
|
|
+```
|
|
|
+
|
|
|
+## Main Features
|
|
|
+
|
|
|
+### ArgFactory
|
|
|
+
|
|
|
+```python
|
|
|
+from cvargparse import GPUParser, ArgFactory, Arg
|
|
|
+
|
|
|
+factory = ArgFactory([
|
|
|
+ Arg("data", type=str),
|
|
|
+ Arg("labels", type=str),
|
|
|
+ Arg("model_weights", type=str),
|
|
|
+])
|
|
|
+
|
|
|
+facotry.epochs()
|
|
|
+facotry.batch_size()
|
|
|
+factory.weight_decay(5e-3)
|
|
|
+factory.learning_rate(lr=1e-3)
|
|
|
+factory.debug().seed()
|
|
|
+
|
|
|
+parser = GPUParser(factory)
|
|
|
+args = parser.parse_args()
|
|
|
+```
|
|
|
+
|
|
|
+### Argument Choices
|
|
|
+```python
|
|
|
+import logging
|
|
|
+from cvargparse.utils import BaseChoiceType
|
|
|
+from dlframework.models import VGG19, ResNet, InceptionV3
|
|
|
+from dlframework.optimizers import Adam, RMSProp, MomentumSGD
|
|
|
+
|
|
|
+
|
|
|
+class ModelTypes(BaseChoiceType):
|
|
|
+ Default = ResNet
|
|
|
+ Resnet = ResNet
|
|
|
+ VGG = VGG19
|
|
|
+ Inception = InceptionV3
|
|
|
+
|
|
|
+
|
|
|
+class OptimizerTypes(BaseChoiceType):
|
|
|
+ Default = Adam
|
|
|
+ adam = Adam
|
|
|
+ rms = RMSProp
|
|
|
+ sgd = MomentumSGD
|
|
|
+
|
|
|
+
|
|
|
+def main(args):
|
|
|
+
|
|
|
+ model_type = ModelType.get(args.model_type)
|
|
|
+ logging.info("Creating \"{}\" model".format(model_type.name))
|
|
|
+ model_cls = model_type.value
|
|
|
+ model = model_cls(args.model_weights)
|
|
|
+
|
|
|
+ opt_type = OptimizerType.get(args.model_type)
|
|
|
+ logging.info("Using \"{}\" optimizer".format(opt_type.name))
|
|
|
+ opt_cls = opt_type.value
|
|
|
+ opt = opt_cls(args.learning_rate, model)
|
|
|
+
|
|
|
+
|
|
|
+ # further training / optimization code
|
|
|
+
|
|
|
+factory = ArgFactory([
|
|
|
+ Arg("data", type=str),
|
|
|
+ Arg("labels", type=str),
|
|
|
+ Arg("model_weights", type=str),
|
|
|
|
|
|
+ ModelTypes.as_arg(name="model_type", short_name="mt", help_text="Model type selection"),
|
|
|
+ OptimizerTypes.as_arg(name="optimizer", short_name="opt", help_text="Optimizer selection"),
|
|
|
+])
|
|
|
+
|
|
|
+facotry.epochs()
|
|
|
+facotry.batch_size()
|
|
|
+factory.weight_decay(5e-3)
|
|
|
+factory.learning_rate(lr=1e-3)
|
|
|
+factory.debug().seed()
|
|
|
+
|
|
|
+parser = GPUParser(factory)
|
|
|
+parser.init_logger()
|
|
|
+
|
|
|
+main(parser.parse_args())
|
|
|
```
|
|
|
|
|
|
+```bash
|
|
|
+python script.py path/to/data path/to/labels path/to/model \
|
|
|
+ --model_type resnet
|
|
|
+ --optimizer adam
|
|
|
+ ...
|
|
|
+```
|