Browse Source

updated pre-trained VGG net definition

Dimitri Korsch 2 năm trước cách đây
mục cha
commit
ff8358a565
1 tập tin đã thay đổi với 24 bổ sung16 xóa
  1. 24 16
      cvmodelz/models/pretrained/vgg.py

+ 24 - 16
cvmodelz/models/pretrained/vgg.py

@@ -1,26 +1,13 @@
 from chainer import links as L
-from chainer.links.model.vision.vgg import prepare
+from chainer.links.model.vision.vgg import prepare as vgg_prepare
 from chainer.links.model.vision.vgg import _max_pooling_2d
 
 from cvmodelz.models.meta_info import ModelInfo
 from cvmodelz.models.pretrained.base import PretrainedModelMixin
 
-def _vgg_meta(final_conv_layer):
-	return ModelInfo(
-		name="VGG",
-		input_size=224,
-		feature_size=4096,
-		n_conv_maps=512,
-
-		conv_map_layer=final_conv_layer,
-		feature_layer="fc7",
-
-		classifier_layers=["fc6", "fc7", "fc8"],
-
-		prepare_func=prepare,
-	)
 
 class BaseVGG(PretrainedModelMixin):
+
 	def __init__(self, *args, **kwargs):
 		kwargs["pooling"] = kwargs.get("pooling", _max_pooling_2d)
 		super().__init__(*args, **kwargs)
@@ -30,7 +17,28 @@ class BaseVGG(PretrainedModelMixin):
 		return super().functions
 
 	def init_model_info(self):
-		self.meta = _vgg_meta(self.final_conv_layer)
+		self.meta = ModelInfo(
+			name="VGG",
+			input_size=224,
+			feature_size=4096,
+			n_conv_maps=512,
+
+			conv_map_layer=self.final_conv_layer,
+			feature_layer="fc7",
+
+			classifier_layers=["fc6", "fc7", "fc8"],
+
+			prepare_func=self.prepare,
+		)
+
+	def prepare(self, x, size=None, *, swap_channels=True, keep_ratio=True):
+		x = vgg_prepare(x, size=size)
+
+		# if not desired, we need to undo it
+		if not swap_channels:
+			x = x[:, :, ::-1]
+
+		return x
 
 class VGG19(BaseVGG, L.VGG19Layers):
 	final_conv_layer = "conv5_4"