瀏覽代碼

added more documentation

Dimitri Korsch 6 年之前
父節點
當前提交
62a1628eb0
共有 1 個文件被更改,包括 87 次插入2 次删除
  1. 87 2
      README.md

+ 87 - 2
README.md

@@ -40,7 +40,6 @@ def main(args):
 
 
 parser.init_logger()
 parser.init_logger()
 main(parser.parse_args())
 main(parser.parse_args())
-
 ```
 ```
 
 
 This script can be called as following:
 This script can be called as following:
@@ -52,7 +51,93 @@ python script.py path/to/data path/to/labels path/to/model \
     --batch_size 32 \
     --batch_size 32 \
     --epochs 90 \
     --epochs 90 \
     --loglevel DEBUG \
     --loglevel DEBUG \
-    --logfile path/to/logs
+    --logfile path/to/logs    
+```
+
+## Main Features
+
+### ArgFactory
+
+```python
+from cvargparse import GPUParser, ArgFactory, Arg
+
+factory = ArgFactory([
+    Arg("data", type=str),
+    Arg("labels", type=str),
+    Arg("model_weights", type=str),
+])
+
+facotry.epochs()
+facotry.batch_size()
+factory.weight_decay(5e-3)
+factory.learning_rate(lr=1e-3)
+factory.debug().seed()
+
+parser = GPUParser(factory)
+args = parser.parse_args()
+```
+
+### Argument Choices
+```python
+import logging
+from cvargparse.utils import BaseChoiceType
+from dlframework.models import VGG19, ResNet, InceptionV3
+from dlframework.optimizers import Adam, RMSProp, MomentumSGD
+
+
+class ModelTypes(BaseChoiceType):
+    Default = ResNet
+    Resnet = ResNet
+    VGG = VGG19
+    Inception = InceptionV3
+
+
+class OptimizerTypes(BaseChoiceType):
+    Default = Adam
+    adam = Adam
+    rms = RMSProp
+    sgd = MomentumSGD
+    
+
+def main(args):
+
+    model_type = ModelType.get(args.model_type)
+    logging.info("Creating \"{}\" model".format(model_type.name))
+    model_cls = model_type.value
+    model = model_cls(args.model_weights)
+    
+    opt_type = OptimizerType.get(args.model_type)
+    logging.info("Using \"{}\" optimizer".format(opt_type.name))
+    opt_cls = opt_type.value
+    opt = opt_cls(args.learning_rate, model)
+    
+    
+    # further training / optimization code
+
+factory = ArgFactory([
+    Arg("data", type=str),
+    Arg("labels", type=str),
+    Arg("model_weights", type=str),
     
     
+    ModelTypes.as_arg(name="model_type", short_name="mt", help_text="Model type selection"),
+    OptimizerTypes.as_arg(name="optimizer", short_name="opt", help_text="Optimizer selection"),
+])
+
+facotry.epochs()
+facotry.batch_size()
+factory.weight_decay(5e-3)
+factory.learning_rate(lr=1e-3)
+factory.debug().seed()
+
+parser = GPUParser(factory)
+parser.init_logger()
+
+main(parser.parse_args())
 ```
 ```
 
 
+```bash
+python script.py path/to/data path/to/labels path/to/model \
+    --model_type resnet
+    --optimizer adam
+    ...
+```