Browse Source

Refactored everything

Christoph Theiß 6 years ago
parent
commit
512ebb3afb
6 changed files with 157 additions and 134 deletions
  1. 9 0
      .editorconfig
  2. 5 134
      cvargparse/__init__.py
  3. 18 0
      cvargparse/arguments.py
  4. 57 0
      cvargparse/factory.py
  5. 60 0
      cvargparse/parser.py
  6. 8 0
      cvargparse/utils.py

+ 9 - 0
.editorconfig

@@ -0,0 +1,9 @@
+# editorconfig.org
+root = true
+
+[*]
+indent_style = tab
+end_of_line = lf
+charset = utf-8
+trim_trailing_whitespace = true
+insert_final_newline = true

+ 5 - 134
cvargparse/__init__.py

@@ -1,135 +1,6 @@
-import argparse, logging
+__version__ = "0.1.2"
 
-class Arg(object):
-	def __init__(self, *args, **kw):
-		super(Arg, self).__init__()
-		self.args = args
-		self.kw = kw
-
-class FileArg(Arg):
-	def __init__(self, *args, **kw):
-		super(FileArg, self).__init__(*args, **kw)
-
-	@classmethod
-	def mode(cls, file_mode, encoding=None):
-		def wrapper(*args, **kw):
-			obj = cls(*args, **kw)
-			obj.kw["type"] = argparse.FileType(file_mode, encoding=encoding)
-			return obj
-		return wrapper
-
-class BaseParser(argparse.ArgumentParser):
-	def __init__(self, arglist=[], nologging=False, sysargs=None, *args, **kw):
-		super(BaseParser, self).__init__(*args, **kw)
-		self.__nologging = nologging
-		self.__sysargs = sysargs
-		if isinstance(arglist, ArgFactory):
-			arglist = arglist.get()
-
-		for arg in arglist:
-			if isinstance(arg, Arg):
-				self.add_argument(*arg.args, **arg.kw)
-			else:
-				self.add_argument(*arg[0], **arg[1])
-
-
-		if not self.has_logging: return
-
-		self.add_argument(
-			'--logfile', type=str, default='',
-			help='file for logging output')
-
-		self.add_argument(
-			'--loglevel', type=str, default='INFO',
-			help='logging level. see logging module for more information')
-
-		self.__args = None
-
-
-	@property
-	def args(self):
-		if self.__args is None:
-			self.__args = self.parse_args(self.__sysargs)
-
-		return self.__args
-
-
-	@property
-	def has_logging(self):
-		return not self.__nologging
-
-	def init_logger(self, simple=False):
-		if not self.has_logging: return
-		fmt = '%(message)s' if simple else '%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s'
-		logging.basicConfig(
-			format=fmt,
-			level=getattr(logging, self.args.loglevel.upper(), logging.DEBUG),
-			filename=self.args.logfile or None,
-			filemode="w")
-
-
-class GPUParser(BaseParser):
-	def __init__(self, *args, **kw):
-		super(GPUParser, self).__init__(*args, **kw)
-		self.add_argument(
-			"--gpu", "-g", type=int, nargs="+", default=[-1],
-			help="which GPU to use. select -1 for CPU only")
-
-
-def factory(func):
-	"""
-		Returns 'self' at the end
-	"""
-	def inner(self, *args, **kw):
-		func(self, *args, **kw)
-		return self
-	return inner
-
-class ArgFactory(object):
-	def __init__(self, initial=[]):
-		super(ArgFactory, self).__init__()
-		self.args = initial
-
-	def get(self):
-		return self.args
-
-	@factory
-	def batch_size(self):
-		self.args.append(
-			Arg("--batch_size", "-b", type=int, default=32, help="batch size")
-		)
-
-	@factory
-	def epochs(self):
-		self.args.append(
-			Arg("--epochs", "-e", type=int, default=30, help="number of epochs"),
-		)
-
-	@factory
-	def debug(self):
-		self.args.append(
-			Arg("--debug", action="store_true", help="enable chainer debug mode"),
-		)
-
-	@factory
-	def seed(self):
-		self.args.append(
-			Arg("--seed", type=int, default=None, help="random seed"),
-		)
-
-	@factory
-	def weight_decay(self, default=5e-3):
-		self.args.append(
-			Arg("--decay", type=float, default=default, help="weight decay"),
-		)
-
-	@factory
-	def learning_rate(self, lr=1e-2, lrs=10, lrd=1e-1, lrt=1e-6):
-		self.args.extend([
-			Arg("--learning_rate", "-lr", type=float, default=lr, help="learning rate"),
-			Arg("--lr_shift", "-lrs", type=int, default=lrs, help="learning rate shift interval (in epochs)"),
-			Arg("--lr_decrease_rate", "-lrd", type=float, default=lrd, help="learning rate decrease"),
-			Arg("--lr_target", "-lrt", type=float, default=lrt, help="learning rate target"),
-		])
-
-__version__ = "0.1.1"
+from .arguments import Argument, FileArgument
+from .factory import ArgFactory
+from .parser import BaseParser, GPUParser
+Arg = Argument

+ 18 - 0
cvargparse/arguments.py

@@ -0,0 +1,18 @@
+class Argument(object):
+	def __init__(self, *args, **kw):
+		super(Argument, self).__init__()
+		self.args = args # positional arugments
+		self.kw = kw # keyword arguments
+
+
+class FileArgument(Argument):
+	def __init__(self, *args, **kw):
+		super(FileArgument, self).__init__(*args, **kw)
+
+	@classmethod
+	def mode(cls, file_mode, encoding=None):
+		def wrapper(*args, **kw):
+			obj = cls(*args, **kw)
+			obj.kw["type"] = argparse.FileType(file_mode, encoding=encoding)
+			return obj
+		return wrapper

+ 57 - 0
cvargparse/factory.py

@@ -0,0 +1,57 @@
+from cvargparse.utils import factory 
+from cvargparse.arguments import Argument as Arg
+
+from abc import ABC
+
+class BaseFactory(ABC):
+	'''
+
+	'''
+	def __init__(self, initial=None):
+		super(BaseFactory, self).__init__()
+		self.args = initial or []
+
+	@factory
+	def add(self, *args, **kwargs):
+		self.args.append(Arg(*args, **kwargs))
+
+
+	def get(self):
+		return self.args
+ 
+
+class ArgFactory(BaseFactory):
+	'''
+	
+	'''
+	@factory
+	def batch_size(self):
+		self.add('--batch_size', '-b', type=int, default=32, help='batch size')
+
+
+	@factory
+	def epochs(self):
+		self.add('--epochs', '-e', type=int, default=30, help='number of epochs')
+
+
+	@factory
+	def debug(self):
+		self.add('--debug', action='store_true', help='enable chainer debug mode')
+
+
+	@factory
+	def seed(self):
+		self.add('--seed', type=int, default=None, help='random seed')
+
+
+	@factory
+	def weight_decay(self, default=5e-3):
+		self.add('--decay', type=float, default=default, help='weight decay')
+
+
+	@factory
+	def learning_rate(self, lr=1e-2, lrs=10, lrd=1e-1, lrt=1e-6):
+		self.add('--learning_rate', '-lr', type=float, default=lr, help='learning rate')
+		self.add('--lr_shift', '-lrs', type=int, default=lrs, help='learning rate shift interval (in epochs)')
+		self.add('--lr_decrease_rate', '-lrd', type=float, default=lrd, help='learning rate decrease')
+		self.add('--lr_target', '-lrt', type=float, default=lrt, help='learning rate target')

+ 60 - 0
cvargparse/parser.py

@@ -0,0 +1,60 @@
+import argparse, logging
+
+from arguments import Argument, FileArgument
+
+class BaseParser(argparse.ArgumentParser):
+	def __init__(self, arglist=[], nologging=False, sysargs=None, *args, **kw):
+		super(BaseParser, self).__init__(*args, **kw)
+		self.__nologging = nologging
+		self.__sysargs = sysargs
+		if isinstance(arglist, ArgFactory):
+			arglist = arglist.get()
+
+		for arg in arglist:
+			if isinstance(arg, Arg):
+				self.add_argument(*arg.args, **arg.kw)
+			else:
+				self.add_argument(*arg[0], **arg[1])
+
+
+		if not self.has_logging: return
+
+		self.add_argument(
+			'--logfile', type=str, default='',
+			help='file for logging output')
+
+		self.add_argument(
+			'--loglevel', type=str, default='INFO',
+			help='logging level. see logging module for more information')
+
+		self.__args = None
+
+
+	@property
+	def args(self):
+		if self.__args is None:
+			self.__args = self.parse_args(self.__sysargs)
+
+		return self.__args
+
+
+	@property
+	def has_logging(self):
+		return not self.__nologging
+
+	def init_logger(self, simple=False):
+		if not self.has_logging: return
+		fmt = '%(message)s' if simple else '%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s'
+		logging.basicConfig(
+			format=fmt,
+			level=getattr(logging, self.args.loglevel.upper(), logging.DEBUG),
+			filename=self.args.logfile or None,
+			filemode="w")
+
+
+class GPUParser(BaseParser):
+	def __init__(self, *args, **kw):
+		super(GPUParser, self).__init__(*args, **kw)
+		self.add_argument(
+			"--gpu", "-g", type=int, nargs="+", default=[-1],
+			help="which GPU to use. select -1 for CPU only")

+ 8 - 0
cvargparse/utils.py

@@ -0,0 +1,8 @@
+def factory(func):
+	"""
+		Returns 'self' at the end
+	"""
+	def inner(self, *args, **kw):
+		func(self, *args, **kw)
+		return self
+	return inner