浏览代码

first commit

Dimitri Korsch 6 年之前
父节点
当前提交
12f2a4b513
共有 4 个文件被更改,包括 178 次插入0 次删除
  1. 11 0
      Makefile
  2. 135 0
      cvargparse/__init__.py
  3. 0 0
      requirements.txt
  4. 32 0
      setup.py

+ 11 - 0
Makefile

@@ -0,0 +1,11 @@
+install:
+	python setup.py install
+
+build:
+	python setup.py build
+
+deploy:
+	python setup.py sdist upload -r pypi
+
+get_version:
+	@python -c "import cvargparse; print('v{}'.format(cvargparse.__version__))"

+ 135 - 0
cvargparse/__init__.py

@@ -0,0 +1,135 @@
+import argparse, logging
+
+class Arg(object):
+	def __init__(self, *args, **kw):
+		super(Arg, self).__init__()
+		self.args = args
+		self.kw = kw
+
+class FileArg(Arg):
+	def __init__(self, *args, **kw):
+		super(FileArg, self).__init__(*args, **kw)
+
+	@classmethod
+	def mode(cls, file_mode, encoding=None):
+		def wrapper(*args, **kw):
+			obj = cls(*args, **kw)
+			obj.kw["type"] = argparse.FileType(file_mode, encoding=encoding)
+			return obj
+		return wrapper
+
+class BaseParser(argparse.ArgumentParser):
+	def __init__(self, arglist=[], nologging=False, sysargs=None, *args, **kw):
+		super(BaseParser, self).__init__(*args, **kw)
+		self.__nologging = nologging
+		self.__sysargs = sysargs
+		if isinstance(arglist, ArgFactory):
+			arglist = arglist.get()
+
+		for arg in arglist:
+			if isinstance(arg, Arg):
+				self.add_argument(*arg.args, **arg.kw)
+			else:
+				self.add_argument(*arg[0], **arg[1])
+
+
+		if not self.has_logging: return
+
+		self.add_argument(
+			'--logfile', type=str, default='',
+			help='file for logging output')
+
+		self.add_argument(
+			'--loglevel', type=str, default='INFO',
+			help='logging level. see logging module for more information')
+
+		self.__args = None
+
+
+	@property
+	def args(self):
+		if self.__args is None:
+			self.__args = self.parse_args(self.__sysargs)
+
+		return self.__args
+
+
+	@property
+	def has_logging(self):
+		return not self.__nologging
+
+	def init_logger(self, simple=False):
+		if not self.has_logging: return
+		fmt = '%(message)s' if simple else '%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s'
+		logging.basicConfig(
+			format=fmt,
+			level=getattr(logging, self.args.loglevel.upper(), logging.DEBUG),
+			filename=self.args.logfile or None,
+			filemode="w")
+
+
+class GPUParser(BaseParser):
+	def __init__(self, *args, **kw):
+		super(GPUParser, self).__init__(*args, **kw)
+		self.add_argument(
+			"--gpu", "-g", type=int, nargs="+", default=[-1],
+			help="which GPU to use. select -1 for CPU only")
+
+
+def factory(func):
+	"""
+		Returns 'self' at the end
+	"""
+	def inner(self, *args, **kw):
+		func(self, *args, **kw)
+		return self
+	return inner
+
+class ArgFactory(object):
+	def __init__(self, initial=[]):
+		super(ArgFactory, self).__init__()
+		self.args = initial
+
+	def get(self):
+		return self.args
+
+	@factory
+	def batch_size(self):
+		self.args.append(
+			Arg("--batch_size", "-b", type=int, default=32, help="batch size")
+		)
+
+	@factory
+	def epochs(self):
+		self.args.append(
+			Arg("--epochs", "-e", type=int, default=30, help="number of epochs"),
+		)
+
+	@factory
+	def debug(self):
+		self.args.append(
+			Arg("--debug", action="store_true", help="enable chainer debug mode"),
+		)
+
+	@factory
+	def seed(self):
+		self.args.append(
+			Arg("--seed", type=int, default=None, help="random seed"),
+		)
+
+	@factory
+	def weight_decay(self, default=5e-3):
+		self.args.append(
+			Arg("--decay", type=float, default=default, help="weight decay"),
+		)
+
+	@factory
+	def learning_rate(self, lr=1e-2, lrs=10, lrd=1e-1, lrt=1e-6):
+		self.args.extend([
+			Arg("--learning_rate", "-lr", type=float, default=lr, help="learning rate"),
+			Arg("--lr_shift", "-lrs", type=int, default=lrs, help="learning rate shift interval (in epochs)"),
+			Arg("--lr_decrease_rate", "-lrd", type=float, default=lrd, help="learning rate decrease"),
+			Arg("--lr_target", "-lrt", type=float, default=lrt, help="learning rate target"),
+		])
+
+__version__ = "0.1.1"

+ 0 - 0
requirements.txt


+ 32 - 0
setup.py

@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+
+import os
+import pkg_resources
+import sys
+
+from setuptools import setup, find_packages
+
+try: # for pip >= 10
+	from pip._internal.req import parse_requirements
+except ImportError: # for pip <= 9.0.3
+	from pip.req import parse_requirements
+
+import cvargparse
+# install_requires = [line.strip() for line in open("requirements.txt").readlines()]
+
+setup(
+	name='cvargparse',
+	version=cvargparse.__version__,
+	description='simple argparse wrapper with some syntactic sugar',
+	author='Dimitri Korsch',
+	author_email='korschdima@gmail.com',
+	license='MIT License',
+	packages=find_packages(),
+	zip_safe=False,
+	setup_requires=[],
+	# no requirements yet
+	# install_requires=install_requires,
+	package_data={'': ['requirements.txt']},
+	data_files=[('.',['requirements.txt'])],
+	include_package_data=True,
+)