|
@@ -17,6 +17,7 @@ class Classifier(chainer.Chain):
|
|
only_head: bool = False,
|
|
only_head: bool = False,
|
|
):
|
|
):
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
+ self._only_head = only_head
|
|
self.layer_name = layer_name or model.clf_layer_name
|
|
self.layer_name = layer_name or model.clf_layer_name
|
|
self.loss_func = loss_func
|
|
self.loss_func = loss_func
|
|
|
|
|