|
@@ -26,17 +26,17 @@ def main(args):
|
|
|
for batch in it:
|
|
|
updater.train(model, batch)
|
|
|
|
|
|
- parser = GPUParser(ArgFactory([
|
|
|
- Arg("data", type=str),
|
|
|
- Arg("labels", type=str),
|
|
|
- Arg("model_weights", type=str),
|
|
|
- ])\
|
|
|
- .epochs()\
|
|
|
- .batch_size()\
|
|
|
- .learning_rate(lr=1e-3)\
|
|
|
- .weight_decay(5e-3)\
|
|
|
- .seed()\
|
|
|
- .debug())
|
|
|
+parser = GPUParser(ArgFactory([
|
|
|
+ Arg("data", type=str),
|
|
|
+ Arg("labels", type=str),
|
|
|
+ Arg("model_weights", type=str),
|
|
|
+])\
|
|
|
+.epochs()\
|
|
|
+.batch_size()\
|
|
|
+.learning_rate(lr=1e-3)\
|
|
|
+.weight_decay(5e-3)\
|
|
|
+.seed()\
|
|
|
+.debug())
|
|
|
|
|
|
parser.init_logger()
|
|
|
main(parser.parse_args())
|