# main.py
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. import socket
  4. if socket.gethostname() != "sigma25":
  5. import matplotlib
  6. matplotlib.use('Agg')
  7. import chainer
  8. import logging
  9. from chainer.training.updaters import StandardUpdater
  10. from finetune.finetuner import DefaultFinetuner
  11. from utils import parser
  12. from core import classifier, dataset, trainer
  13. def main(args):
  14. if args.debug:
  15. chainer.set_debug(args.debug)
  16. logging.warning("DEBUG MODE ENABLED!")
  17. tuner = DefaultFinetuner(
  18. args,
  19. classifier_cls=classifier.FVEClassifier,
  20. classifier_kwargs=dict(
  21. n_comps=args.n_components,
  22. fv_insize=args.fv_insize,
  23. alpha=args.alpha,
  24. ),
  25. model_kwargs=dict(
  26. pooling=args.pooling,
  27. ),
  28. dataset_cls=dataset.PartsDataset,
  29. updater_cls=StandardUpdater,
  30. updater_kwargs={},
  31. )
  32. tuner.run(trainer_cls=trainer.PartsTrainer, opts=args)
  33. main(parser.parse_args())