|
@@ -72,7 +72,7 @@ class BaseModel(abc.ABC, chainer.Chain):
|
|
|
w_shape = (n_classes, feat_size or clf_layer.W.shape[1])
|
|
|
dtype = clf_layer.W.dtype
|
|
|
|
|
|
- clf_layer.in_size, clf_layer.in_size = w_shape
|
|
|
+ clf_layer.out_size, clf_layer.in_size = w_shape
|
|
|
clf_layer.W.data = np.zeros(w_shape, dtype=dtype)
|
|
|
clf_layer.b.data = np.zeros(w_shape[0], dtype=dtype)
|
|
|
|