Pārlūkot izejas kodu

added support for infering input size from the set model type

Dimitri Korsch 1 gadu atpakaļ
vecāks
revīzija
dad0a4a4d0

+ 1 - 1
cvfinetune/_version.py

@@ -1 +1 @@
-__version__ = "0.11.0"
+__version__ = "0.12.0"

+ 3 - 6
cvfinetune/finetuner/factory.py

@@ -2,9 +2,9 @@ import logging
 import warnings
 
 from cvfinetune import utils
-from cvfinetune.utils import mpi
 from cvfinetune.finetuner.base import DefaultFinetuner
 from cvfinetune.finetuner.mpi import MPIFinetuner
+from cvfinetune.utils import mpi
 
 from cvdatasets.utils import pretty_print_dict
 
@@ -12,12 +12,9 @@ class FinetunerFactory:
 
 	@classmethod
 	def new(cls, *args, **kwargs):
-		raise NotImplementedError("Use simple instance creation instead of {cls.__name__}.new()!")
+		raise NotImplementedError(f"Use simple instance creation instead of {cls.__name__}.new()!")
 
-	def __init__(self, *,
-		         default=DefaultFinetuner,
-		         mpi_tuner=MPIFinetuner,
-		         **kwargs):
+	def __init__(self, *, default=DefaultFinetuner, mpi_tuner=MPIFinetuner, **kwargs):
 		super().__init__()
 
 		if "mpi" in kwargs:

+ 7 - 0
cvfinetune/parser/base.py

@@ -7,6 +7,7 @@ from cvargparse import ArgFactory
 from cvargparse import GPUParser
 from cvargparse.utils import logger_config
 
+from cvfinetune.parser import utils
 from cvfinetune.parser.dataset_args import add_dataset_args
 from cvfinetune.parser.model_args import add_model_args
 from cvfinetune.parser.training_args import add_training_args
@@ -26,6 +27,12 @@ class FineTuneParser(GPUParser):
 		add_model_args(self, model_modules=model_modules)
 		add_training_args(self)
 
+	def parse_args(self, *args, **kwargs):
+		opts = super().parse_args(*args, **kwargs)
+		utils.parse_input_size(opts, attr="input_size")
+		utils.parse_input_size(opts, attr="part_input_size")
+		return opts
+
 
 	def _logging_config(self, simple=False):
 		if not self.has_logging: return

+ 11 - 8
cvfinetune/parser/model_args.py

@@ -2,11 +2,12 @@ import abc
 import typing as T
 
 from chainer_addons.links import PoolingType
-from chainer_addons.models import PrepareType
 from cvargparse import Arg
 from cvargparse import BaseParser
+
 from cvfinetune.parser.utils import parser_extender
 from cvmodelz.models import ModelFactory
+from cvmodelz.models import PrepareType
 
 @parser_extender
 def add_model_args(parser: BaseParser, *,
@@ -27,11 +28,13 @@ def add_model_args(parser: BaseParser, *,
 			choices=["imagenet", "inat"],
 			help="type of model pre-training"),
 
-		Arg("--input_size", type=int, nargs="+", default=0,
-			help="overrides default input size of the model, if greater than 0"),
+		Arg("--input_size", default="infer",
+			help="overrides default input size of the model, if greater than 0. " \
+			"Set to \"infer\" to guess the input size from the model type."),
 
-		Arg("--part_input_size", type=int, nargs="+", default=0,
-			help="overrides default input part size of the model, if greater than 0"),
+		Arg("--part_input_size", default="infer",
+			help="overrides default input part size of the model, if greater than 0. " \
+			"Set to \"infer\" to guess the input size from the model type."),
 
 		PrepareType.as_arg("prepare_type",
 			help_text="type of image preprocessing"),
@@ -39,10 +42,10 @@ def add_model_args(parser: BaseParser, *,
 		PoolingType.as_arg("pooling",
 			help_text="type of pre-classification pooling"),
 
-		Arg("--load", type=str,
+		Arg("--load",
 			help="ignore weights and load already fine-tuned model (classifier will NOT be re-initialized and number of classes will be unchanged)"),
 
-		Arg("--weights", type=str,
+		Arg("--weights",
 			help="ignore default weights and load already pre-trained model (classifier will be re-initialized and number of classes will be changed)"),
 
 		Arg("--headless", action="store_true",
@@ -51,7 +54,7 @@ def add_model_args(parser: BaseParser, *,
 		Arg("--load_strict", action="store_true",
 			help="load weights in a strict mode"),
 
-		Arg("--load_path", type=str, default="",
+		Arg("--load_path", default="",
 			help="load path within the weights archive"),
 	]
 

+ 23 - 1
cvfinetune/parser/utils.py

@@ -102,4 +102,26 @@ def get_n_classes(weights_file: str, load_path: str = "",
 	return -1
 
 
-
+def parse_input_size(args, default: int = 224, attr: str = "input_size"):
+	model_type = args.model_type.lower()
+	value = getattr(args, attr)
+	if value == "infer":
+		if "resnet" in model_type:
+			size = 224
+		elif "inception" in model_type:
+			size = 299
+		else:
+			size = default
+
+	elif value == "infer_big":
+		if "resnet" in model_type:
+			size = 448
+		elif "inception" in model_type:
+			# alternatives: 453, 552, 585
+			size = 427
+		else:
+			size = default * 2
+	else:
+		size = int(value)
+	setattr(args, attr, size)
+	return (size, size)

+ 1 - 6
examples/basic/main.py

@@ -1,11 +1,6 @@
 #!/usr/bin/env python
 if __name__ != '__main__': raise Exception("Do not import me!")
 
-import socket
-if socket.gethostname() != "sigma25":
-	import matplotlib
-	matplotlib.use('Agg')
-
 import chainer
 import logging
 
@@ -25,7 +20,7 @@ def main(args):
 		chainer.set_debug(args.debug)
 		logging.warning("DEBUG MODE ENABLED!")
 
-	factory = FinetunerFactory.new(mpi=False)
+	factory = FinetunerFactory(mpi=False)
 
 	tuner = factory(args,
 		classifier_cls=Classifier,

+ 1 - 29
examples/basic/utils/parser.py

@@ -1,7 +1,4 @@
-import os
-
-from cvargparse import GPUParser, Arg
-from chainer_addons.links import PoolingType
+from cvargparse import Arg
 
 from cvfinetune.parser import default_factory
 from cvfinetune import parser as parser_module
@@ -13,31 +10,6 @@ def parse_args():
 
 			Arg("--pretrained_on", choices=["inat", "imagenet"],
 				help="network pretraining"),
-
-
-			# Arg("--normalize", action="store_true",
-			# 	help="normalize features after cbil- or alpha-poolings"),
-
-			# Arg("--subset", "-s", type=int, nargs="*", default=[-1], help="select specific classes"),
-			# Arg("--no_sacred", action="store_true", help="do save outputs to sacred"),
-
-			# Arg("--use_parts", action="store_true",
-			# 	help="use parts, if present"),
-			# Arg("--simple_parts", action="store_true",
-			# 	help="use simple parts classifier, that only concatenates the features"),
-			# Arg("--no_global", action="store_true",
-			# 	help="use parts only, no global feature"),
-
-
-			# Arg("--parts_in_bb", action="store_true", help="take only uniform regions where the centers are inside the bounding box"),
-
-			# Arg("--rnd_select", action="store_true", help="hide random uniform regions of the image"),
-			# Arg("--recurrent", action="store_true", help="observe all parts in recurrent manner instead of the whole image at once"),
-
-			# ## AlphaPooling options
-			# Arg("--init_alpha", type=int, default=1, help="initial parameter for alpha pooling"),
-			# Arg("--kappa", type=float, default=1., help="Learning rate factor for alpha pooling"),
-			# Arg("--switch_epochs", type=int, default=0, help="train alpha pooling layer and the rest of the network alternating")
 		])
 	)
 

+ 0 - 5
examples/fve_example/main.py

@@ -1,11 +1,6 @@
 #!/usr/bin/env python
 if __name__ != '__main__': raise Exception("Do not import me!")
 
-import socket
-if socket.gethostname() != "sigma25":
-	import matplotlib
-	matplotlib.use('Agg')
-
 import chainer
 import logging
 

+ 1 - 0
requirements.txt

@@ -15,6 +15,7 @@ chainercv2
 cvargparse~=0.3
 cvdatasets~=0.9
 chainer_addons~=0.9
+cvmodelz~=0.4.0
 
 wandb
 pymongo

+ 2 - 6
setup.py

@@ -1,16 +1,12 @@
 #!/usr/bin/env python
 
-import os
-import pkg_resources
-import sys
-
 from setuptools import setup, find_packages
 from pathlib import Path
 
 try: # for pip >= 10
 	from pip._internal.req import parse_requirements
 except ImportError: # for pip <= 9.0.3
-	from pip.req import parse_requirements
+	from pip.req import parse_requirements  # noqa: F401
 
 pkg_name = "cvfinetune"
 
@@ -24,7 +20,7 @@ install_requires = [line.strip() for line in open("requirements.txt").readlines(
 setup(
 	name='cvfinetune',
 	python_requires=">3.7",
-	version=__version__,
+	version=__version__,  # noqa: F821
 	description='Fine-tune framework based on chainer',
 	author='Dimitri Korsch',
 	author_email='korschdima@gmail.com',