main.py 818 B

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")  # script-only file: abort immediately when it is imported instead of executed
  3. import socket
  4. if socket.gethostname() != "sigma25":  # every host whose name is not "sigma25" takes the backend override below
  5. import matplotlib  # NOTE(review): indentation was lost in this listing; presumably nested under the hostname check - confirm
  6. matplotlib.use('Agg')  # select the 'Agg' backend before the remaining imports run; presumably so no later import picks a backend first - TODO confirm
  7. import chainer
  8. import logging
  9. from chainer.training.updaters import StandardUpdater
  10. from cvfinetune.dataset import BaseDataset
  11. from cvfinetune.finetuner import FinetunerFactory
  12. from cvfinetune.training.trainer import Trainer
  13. from cvmodelz.classifiers import Classifier
  14. from utils import parser
  15. def main(args):
  16. if args.debug:
  17. chainer.set_debug(args.debug)
  18. logging.warning("DEBUG MODE ENABLED!")
  19. factory = FinetunerFactory.new(mpi=False)
  20. tuner = factory(args,
  21. classifier_cls=Classifier,
  22. dataset_cls=BaseDataset,
  23. updater_cls=StandardUpdater,
  24. )
  25. tuner.run(trainer_cls=Trainer, opts=args)
  26. main(parser.parse_args())