main.py 934 B

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. import socket
  4. if socket.gethostname() != "sigma25":
  5. import matplotlib
  6. matplotlib.use('Agg')
  7. import chainer
  8. import logging
  9. from chainer.training.updaters import StandardUpdater
  10. from chainer_addons.models.classifier import Classifier
  11. from finetune.finetuner import DefaultFinetuner
  12. from finetune.training.trainer import Trainer
  13. from finetune.dataset import BaseDataset
  14. from finetune.classifier import Classifier
  15. from utils import parser
  16. def main(args):
  17. if args.debug:
  18. chainer.set_debug(args.debug)
  19. logging.warning("DEBUG MODE ENABLED!")
  20. tuner = DefaultFinetuner(
  21. args,
  22. classifier_cls=Classifier,
  23. classifier_kwargs={},
  24. model_kwargs=dict(
  25. pooling=args.pooling,
  26. ),
  27. dataset_cls=BaseDataset,
  28. updater_cls=StandardUpdater,
  29. updater_kwargs={},
  30. )
  31. tuner.run(trainer_cls=Trainer, opts=args)
  32. main(parser.parse_args())