factory.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from cvargparse.utils import factory
  2. from cvargparse.argument import Argument as Arg
  3. from abc import ABC
  class BaseFactory(ABC):
      """Collect command-line argument definitions as ``Argument`` objects.

      Arguments are appended through :meth:`add` and retrieved with
      :meth:`get`. Methods decorated with ``@factory`` (from
      ``cvargparse.utils``) are presumably made to return the factory itself
      so calls can be chained -- NOTE(review): confirm against
      ``cvargparse.utils.factory``.
      """

      def __init__(self, initial=None):
          """Create a factory, optionally seeded with existing arguments.

          Args:
              initial: optional iterable of ``Argument`` objects to start
                  from. A falsy value (``None``, empty container) starts
                  with an empty list.
          """
          super().__init__()
          # Copy instead of aliasing. Otherwise later ``add`` calls would
          # mutate the caller's list in place, and a tuple (which has no
          # ``append``) would make ``add`` fail. The truthiness test matches
          # the former ``initial or []``.
          self.args = list(initial) if initial else []

      @factory
      def add(self, *args, **kwargs):
          """Append one ``Argument`` built from the given arguments.

          All positional and keyword arguments are forwarded unchanged to
          ``cvargparse.argument.Argument`` (imported as ``Arg``). Presumably
          they follow ``argparse.ArgumentParser.add_argument`` -- confirm in
          cvargparse.
          """
          self.args.append(Arg(*args, **kwargs))

      def get(self):
          """Return the internal list of collected arguments (not a copy)."""
          return self.args
  class ArgFactory(BaseFactory):
      """Factory with ready-made definitions of common training arguments.

      Each method registers one or more arguments through :meth:`add`. The
      optional parameters override the default value of the registered
      argument, while flag names and help texts stay fixed.
      """

      @factory
      def batch_size(self, default=32):
          """Register ``--batch_size`` / ``-b`` (int).

          Args:
              default: default batch size (previously fixed at 32).
          """
          self.add('--batch_size', '-b', type=int, default=default, help='batch size')

      @factory
      def epochs(self, default=30):
          """Register ``--epochs`` / ``-e`` (int).

          Args:
              default: default number of epochs (previously fixed at 30).
          """
          self.add('--epochs', '-e', type=int, default=default, help='number of epochs')

      @factory
      def debug(self):
          """Register the ``--debug`` switch (``action='store_true'``)."""
          self.add('--debug', action='store_true', help='enable chainer debug mode')

      @factory
      def seed(self, default=None):
          """Register ``--seed`` (int).

          Args:
              default: default seed value (previously fixed at ``None``).
          """
          self.add('--seed', type=int, default=default, help='random seed')

      @factory
      def weight_decay(self, default=5e-3):
          """Register ``--decay`` (float).

          Args:
              default: default weight-decay value.
          """
          self.add('--decay', type=float, default=default, help='weight decay')

      @factory
      def learning_rate(self, lr=1e-2, lrs=10, lrd=1e-1, lrt=1e-6):
          """Register the four learning-rate related arguments.

          Args:
              lr: default for ``--learning_rate`` / ``-lr`` (float).
              lrs: default for ``--lr_shift`` / ``-lrs`` (int); the help
                  text describes it as a shift interval in epochs.
              lrd: default for ``--lr_decrease_rate`` / ``-lrd`` (float).
              lrt: default for ``--lr_target`` / ``-lrt`` (float).
          """
          self.add('--learning_rate', '-lr', type=float, default=lr, help='learning rate')
          self.add('--lr_shift', '-lrs', type=int, default=lrs, help='learning rate shift interval (in epochs)')
          self.add('--lr_decrease_rate', '-lrd', type=float, default=lrd, help='learning rate decrease')
          self.add('--lr_target', '-lrt', type=float, default=lrt, help='learning rate target')