parser.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import os
  2. import logging
  3. import platform
  4. from chainer_addons.training import OptimizerType
  5. from chainer_addons.models import PrepareType
  6. from cvargparse import GPUParser, Arg, ArgFactory
  7. from cvdatasets.utils import read_info_file
  8. DEFAULT_INFO_FILE=os.environ.get("DATA", "/home/korsch/Data/info.yml")
  9. info_file = read_info_file(DEFAULT_INFO_FILE)
  10. def default_factory(extra_list=[]):
  11. return ArgFactory(extra_list + [
  12. Arg("data", default=DEFAULT_INFO_FILE),
  13. Arg("dataset", choices=info_file.DATASETS.keys()),
  14. Arg("parts", choices=info_file.PARTS.keys()),
  15. Arg("--model_type", "-mt",
  16. default="resnet", choices=info_file.MODELS.keys(),
  17. help="type of the model"),
  18. Arg("--input_size", type=int, nargs="+", default=0,
  19. help="overrides default input size of the model, if greater than 0"),
  20. PrepareType.as_arg("prepare_type",
  21. help_text="type of image preprocessing"),
  22. Arg("--load", type=str, help="ignore weights and load already fine-tuned model"),
  23. Arg("--n_jobs", "-j", type=int, default=0,
  24. help="number of loading processes. If 0, then images are loaded in the same process"),
  25. Arg("--warm_up", type=int, help="warm up epochs"),
  26. OptimizerType.as_arg("optimizer", "opt",
  27. help_text="type of the optimizer"),
  28. Arg("--cosine_schedule", action="store_true",
  29. help="enable cosine annealing LR schedule"),
  30. Arg("--l1_loss", action="store_true",
  31. help="(only with \"--only_head\" option!) use L1 Hinge Loss instead of Softmax Cross-Entropy"),
  32. Arg("--from_scratch", action="store_true",
  33. help="Do not load any weights. Train the model from scratch"),
  34. Arg("--label_shift", type=int, default=1,
  35. help="label shift"),
  36. Arg("--swap_channels", action="store_true",
  37. help="preprocessing option: swap channels from RGB to BGR"),
  38. Arg("--label_smoothing", type=float, default=0,
  39. help="Factor for label smoothing"),
  40. Arg("--no_center_crop_on_val", action="store_true",
  41. help="do not center crop imaages in the validation step!"),
  42. Arg("--only_head", action="store_true", help="fine-tune only last layer"),
  43. Arg("--no_progress", action="store_true", help="dont show progress bar"),
  44. Arg("--augment", action="store_true", help="do data augmentation (random croping and random hor. flipping)"),
  45. Arg("--force_load", action="store_true", help="force loading from caffe model"),
  46. Arg("--only_eval", action="store_true", help="evaluate the model only. do not train!"),
  47. Arg("--init_eval", action="store_true", help="evaluate the model before training"),
  48. Arg("--no_snapshot", action="store_true", help="do not save trained model"),
  49. Arg("--output", "-o", type=str, default=".out", help="output folder"),
  50. ])\
  51. .seed()\
  52. .batch_size()\
  53. .epochs()\
  54. .debug()\
  55. .learning_rate(lr=1e-2, lrs=10, lrt=1e-5, lrd=1e-1)\
  56. .weight_decay(default=5e-4)
  57. class FineTuneParser(GPUParser):
  58. def init_logger(self, simple=False, logfile=None):
  59. if not self.has_logging: return
  60. fmt = '{levelname:s} - [{asctime:s}] {filename:s}:{lineno:d} [{funcName:s}]: {message:s}'
  61. handler0 = logging.StreamHandler()
  62. handler0.addFilter(HostnameFilter())
  63. handler0.setFormatter(logging.Formatter("<{hostname:^10s}>: " + fmt, style="{"))
  64. filename = logfile if logfile is not None else f"{platform.node()}.log"
  65. handler1 = logging.FileHandler(filename=filename, mode="w")
  66. handler1.setFormatter(logging.Formatter(fmt, style="{"))
  67. logger = logging.getLogger()
  68. logger.addHandler(handler0)
  69. logger.addHandler(handler1)
  70. logger.setLevel(getattr(logging, self.args.loglevel.upper(), logging.DEBUG))
  71. class HostnameFilter(logging.Filter):
  72. def filter(self, record):
  73. record.hostname = platform.node()
  74. return True