Pārlūkot izejas kodu

added load_path argument. fixed a minor bug

Dimitri Korsch 4 gadi atpakaļ
vecāks
revīzija
452a7ee5e8
3 mainītis faili ar 49 papildinājumiem un 17 dzēšanām
  1. 33 11
      cvfinetune/finetuner/base.py
  2. 14 6
      cvfinetune/parser/model_args.py
  3. 2 0
      requirements.txt

+ 33 - 11
cvfinetune/finetuner/base.py

@@ -12,6 +12,9 @@ from chainer.optimizer_hooks import WeightDecay
 from chainer.serializers import save_npz
 from chainer.training import extensions
 
+from chainercv2.model_provider import get_model
+from chainercv2.models import model_store
+
 from chainer_addons.functions import smoothed_cross_entropy
 from chainer_addons.models import Classifier
 from chainer_addons.models import ModelType
@@ -45,12 +48,18 @@ class _ModelMixin(abc.ABC):
 	model wrapping around a classifier and model weights loading.
 	"""
 
-	def __init__(self, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
-		super(_ModelMixin, self).__init__(*args, **kwargs)
+	def __init__(self, opts, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
+		super(_ModelMixin, self).__init__(opts=opts, *args, **kwargs)
 		self.classifier_cls = classifier_cls
 		self.classifier_kwargs = classifier_kwargs
+		self.model_type = opts.model_type
 		self.model_kwargs = model_kwargs
 
+
+	@property
+	def model_info(self):
+		return self.data_info.MODELS[self.model_type]
+
 	def wrap_model(self, opts):
 
 		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
@@ -119,8 +128,15 @@ class _ModelMixin(abc.ABC):
 	def init_model(self, opts):
 		"""creates backbone CNN model. This model is wrapped around the classifier later"""
 
+		if self.model_type.startswith("cv2_"):
+			model_type = args.model_type.split("cv2_")[-1]
+		else:
+			model_type = self.model_info.class_key
+
+			# model = get_model(model_type, pretrained=False)
+
 		self.model = ModelType.new(
-			model_type=self.model_info.class_key,
+			model_type=model_type,
 			input_size=Size(opts.input_size),
 			**self.model_kwargs,
 		)
@@ -178,8 +194,10 @@ class _DatasetMixin(abc.ABC):
 		dataset and iterator creation.
 	"""
 
-	def __init__(self, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
-		super(_DatasetMixin, self).__init__(*args, **kwargs)
+	def __init__(self, opts, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
+		super(_DatasetMixin, self).__init__(opts=opts, *args, **kwargs)
+		self.annot = None
+		self.dataset_type = opts.dataset
 		self.dataset_cls = dataset_cls
 		self.dataset_kwargs_factory = dataset_kwargs_factory
 
@@ -187,6 +205,15 @@ class _DatasetMixin(abc.ABC):
 	def n_classes(self):
 		return self.ds_info.n_classes + self.dataset_cls.label_shift
 
+	@property
+	def data_info(self):
+		assert self.annot is not None, "annot attribute was not set!"
+		return self.annot.info
+
+	@property
+	def ds_info(self):
+		return self.data_info.DATASETS[self.dataset_type]
+
 	def new_dataset(self, opts, size, part_size, subset):
 		"""Creates a dataset for a specific subset and certain options"""
 		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
@@ -208,16 +235,11 @@ class _DatasetMixin(abc.ABC):
 		logging.info("Loaded {} images".format(len(ds)))
 		return ds
 
+
 	def init_annotations(self, opts):
 		"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
 
 		self.annot = AnnotationType.new_annotation(opts, load_strict=False)
-
-		self.data_info = self.annot.info
-		self.model_info = self.data_info.MODELS[opts.model_type]
-		self.ds_info = self.data_info.DATASETS[opts.dataset]
-		# self.part_info = self.data_info.PART_TYPES[opts.parts]
-
 		self.dataset_cls.label_shift = opts.label_shift
 
 

+ 14 - 6
cvfinetune/parser/model_args.py

@@ -7,19 +7,27 @@ from cvfinetune.parser.utils import parser_extender
 from chainer_addons.links import PoolingType
 from chainer_addons.models import PrepareType
 
+class ModelChoices(object):
+
+	def __init__(self, choices=[]):
+		super(ModelChoices, self).__init__()
+		self.choices = choices
+
+	def __contains__(self, value):
+		return value.startswith("cv2_") or value in self.choices
+
+	def __iter__(self):
+		return iter(self.choices + ["cv2_<any other model>"])
+
 @parser_extender
 def add_model_args(parser):
 
 	info_file = get_info_file()
-
-	if info_file is None:
-		model_type_choices = None
-	else:
-		model_type_choices = info_file.MODELS.keys()
+	choices = None if info_file is None else info_file.MODELS.keys()
 
 	_args = [
 		Arg("--model_type", "-mt",
-			default="resnet", choices=model_type_choices,
+			default="cv2_resnet50", choices=ModelChoices(choices),
 			help="type of the model"),
 
 		Arg("--input_size", type=int, nargs="+", default=0,

+ 2 - 0
requirements.txt

@@ -8,6 +8,8 @@ simplejson~=3.14
 sacred~=0.7
 
 chainer>=4.2.0,<8.0
+chainercv~=0.13
+chainercv2~=0.0
 # cupy-cuda101>=4.2.0,<7.0
 
 # my own packages