|
@@ -12,6 +12,9 @@ from chainer.optimizer_hooks import WeightDecay
|
|
from chainer.serializers import save_npz
|
|
from chainer.serializers import save_npz
|
|
from chainer.training import extensions
|
|
from chainer.training import extensions
|
|
|
|
|
|
|
|
+from chainercv2.model_provider import get_model
|
|
|
|
+from chainercv2.models import model_store
|
|
|
|
+
|
|
from chainer_addons.functions import smoothed_cross_entropy
|
|
from chainer_addons.functions import smoothed_cross_entropy
|
|
from chainer_addons.models import Classifier
|
|
from chainer_addons.models import Classifier
|
|
from chainer_addons.models import ModelType
|
|
from chainer_addons.models import ModelType
|
|
@@ -45,12 +48,18 @@ class _ModelMixin(abc.ABC):
|
|
model wrapping around a classifier and model weights loading.
|
|
model wrapping around a classifier and model weights loading.
|
|
"""
|
|
"""
|
|
|
|
|
|
- def __init__(self, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
|
|
|
|
- super(_ModelMixin, self).__init__(*args, **kwargs)
|
|
|
|
|
|
+ def __init__(self, opts, classifier_cls, classifier_kwargs={}, model_kwargs={}, *args, **kwargs):
|
|
|
|
+ super(_ModelMixin, self).__init__(opts=opts, *args, **kwargs)
|
|
self.classifier_cls = classifier_cls
|
|
self.classifier_cls = classifier_cls
|
|
self.classifier_kwargs = classifier_kwargs
|
|
self.classifier_kwargs = classifier_kwargs
|
|
|
|
+ self.model_type = opts.model_type
|
|
self.model_kwargs = model_kwargs
|
|
self.model_kwargs = model_kwargs
|
|
|
|
|
|
|
|
+
|
|
|
|
+ @property
|
|
|
|
+ def model_info(self):
|
|
|
|
+ return self.data_info.MODELS[self.model_type]
|
|
|
|
+
|
|
def wrap_model(self, opts):
|
|
def wrap_model(self, opts):
|
|
|
|
|
|
clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
|
|
clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
|
|
@@ -119,8 +128,15 @@ class _ModelMixin(abc.ABC):
|
|
def init_model(self, opts):
|
|
def init_model(self, opts):
|
|
"""creates backbone CNN model. This model is wrapped around the classifier later"""
|
|
"""creates backbone CNN model. This model is wrapped around the classifier later"""
|
|
|
|
|
|
|
|
+ if self.model_type.startswith("cv2_"):
|
|
|
|
+ model_type = args.model_type.split("cv2_")[-1]
|
|
|
|
+ else:
|
|
|
|
+ model_type = self.model_info.class_key
|
|
|
|
+
|
|
|
|
+ # model = get_model(model_type, pretrained=False)
|
|
|
|
+
|
|
self.model = ModelType.new(
|
|
self.model = ModelType.new(
|
|
- model_type=self.model_info.class_key,
|
|
|
|
|
|
+ model_type=model_type,
|
|
input_size=Size(opts.input_size),
|
|
input_size=Size(opts.input_size),
|
|
**self.model_kwargs,
|
|
**self.model_kwargs,
|
|
)
|
|
)
|
|
@@ -178,8 +194,10 @@ class _DatasetMixin(abc.ABC):
|
|
dataset and iterator creation.
|
|
dataset and iterator creation.
|
|
"""
|
|
"""
|
|
|
|
|
|
- def __init__(self, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
|
|
|
|
- super(_DatasetMixin, self).__init__(*args, **kwargs)
|
|
|
|
|
|
+ def __init__(self, opts, dataset_cls, dataset_kwargs_factory, *args, **kwargs):
|
|
|
|
+ super(_DatasetMixin, self).__init__(opts=opts, *args, **kwargs)
|
|
|
|
+ self.annot = None
|
|
|
|
+ self.dataset_type = opts.dataset
|
|
self.dataset_cls = dataset_cls
|
|
self.dataset_cls = dataset_cls
|
|
self.dataset_kwargs_factory = dataset_kwargs_factory
|
|
self.dataset_kwargs_factory = dataset_kwargs_factory
|
|
|
|
|
|
@@ -187,6 +205,15 @@ class _DatasetMixin(abc.ABC):
|
|
def n_classes(self):
|
|
def n_classes(self):
|
|
return self.ds_info.n_classes + self.dataset_cls.label_shift
|
|
return self.ds_info.n_classes + self.dataset_cls.label_shift
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def data_info(self):
|
|
|
|
+ assert self.annot is not None, "annot attribute was not set!"
|
|
|
|
+ return self.annot.info
|
|
|
|
+
|
|
|
|
+ @property
|
|
|
|
+ def ds_info(self):
|
|
|
|
+ return self.data_info.DATASETS[self.dataset_type]
|
|
|
|
+
|
|
def new_dataset(self, opts, size, part_size, subset):
|
|
def new_dataset(self, opts, size, part_size, subset):
|
|
"""Creates a dataset for a specific subset and certain options"""
|
|
"""Creates a dataset for a specific subset and certain options"""
|
|
if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
|
|
if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
|
|
@@ -208,16 +235,11 @@ class _DatasetMixin(abc.ABC):
|
|
logging.info("Loaded {} images".format(len(ds)))
|
|
logging.info("Loaded {} images".format(len(ds)))
|
|
return ds
|
|
return ds
|
|
|
|
|
|
|
|
+
|
|
def init_annotations(self, opts):
|
|
def init_annotations(self, opts):
|
|
"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
|
|
"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
|
|
|
|
|
|
self.annot = AnnotationType.new_annotation(opts, load_strict=False)
|
|
self.annot = AnnotationType.new_annotation(opts, load_strict=False)
|
|
-
|
|
|
|
- self.data_info = self.annot.info
|
|
|
|
- self.model_info = self.data_info.MODELS[opts.model_type]
|
|
|
|
- self.ds_info = self.data_info.DATASETS[opts.dataset]
|
|
|
|
- # self.part_info = self.data_info.PART_TYPES[opts.parts]
|
|
|
|
-
|
|
|
|
self.dataset_cls.label_shift = opts.label_shift
|
|
self.dataset_cls.label_shift = opts.label_shift
|
|
|
|
|
|
|
|
|