# classifier.py

import logging

import chainer
import chainer.functions as F
import chainer.links as L

from chainer_addons.links.fisher_encoding import FVLayer

from finetune.classifier import SeparateModelClassifier
  7. class FVEClassifier(SeparateModelClassifier):
  8. def __init__(self, n_comps=2**4, fv_insize=256, alpha=0.99,
  9. *args, **kwargs):
  10. super(FVEClassifier, self).__init__(*args, **kwargs)
  11. with self.init_scope():
  12. if fv_insize < 1:
  13. self.pre_fv = F.identity
  14. fv_insize = self.model.meta.feature_size
  15. else:
  16. self.pre_fv = L.Convolution2D(
  17. self.model.meta.feature_size,
  18. fv_insize,
  19. ksize=1)
  20. self.pre_fv_bn = L.BatchNormalization(fv_insize)
  21. self.fv_encoding = FVLayer(
  22. fv_insize, n_comps,
  23. alpha=alpha)
  24. self.fv_insize = fv_insize
  25. self.n_comps = n_comps
  26. def __call__(self, *inputs):
  27. parts, X, y = inputs
  28. n, t, c, h, w = parts.shape
  29. _parts = parts.reshape(n*t, c, h, w)
  30. part_convs, _ = self.model(_parts, layer_name=self.model.meta.conv_map_layer)
  31. part_local_feats = self.pre_fv_bn(self.pre_fv(part_convs))
  32. n0, n_feats, conv_h, conv_w = part_local_feats.shape
  33. part_local_feats = F.reshape(part_local_feats, (n, t, n_feats, conv_h, conv_w))
  34. # N x T x C x H x W -> N x T x H x W x C
  35. part_local_feats = F.transpose(part_local_feats, (0, 1, 3, 4, 2))
  36. # N x T x H x W x C -> N x T*H*W x C
  37. part_local_feats = F.reshape(part_local_feats, (n, t*conv_h*conv_w, n_feats))
  38. logits = self.fv_encoding(part_local_feats)
  39. logL, _ = self.fv_encoding.log_proba(part_local_feats, weighted=True)
  40. # may be used later to maximize the log-likelihood
  41. self.neg_logL = -F.mean(logL)
  42. # avarage over all local features
  43. avg_logL = F.logsumexp(logL) - self.xp.log(logL.size)
  44. part_pred = self.model.clf_layer(logits)
  45. part_loss = self.loss(part_pred, y)
  46. part_accu = self.model.accuracy(part_pred, y)
  47. self.report(
  48. part_accu=part_accu,
  49. part_loss=part_loss,
  50. logL=avg_logL
  51. )
  52. glob_loss, glob_pred = self.predict_global(X, y)
  53. pred = part_pred + glob_pred
  54. accuracy = self.model.accuracy(pred, y)
  55. loss = self.loss(pred, y)
  56. self.report(
  57. loss = loss.array,
  58. accuracy = accuracy.array,
  59. )
  60. return loss
  61. def predict_global(self, X, y):
  62. glob_pred, _ = self.separate_model(X)
  63. glob_loss = self.loss(glob_pred, y)
  64. glob_accu = self.separate_model.accuracy(glob_pred, y)
  65. self.report(
  66. glob_loss = glob_loss.data,
  67. glob_accu = glob_accu.data,
  68. )
  69. return glob_loss, glob_pred
  70. @property
  71. def feat_size(self):
  72. return self.model.meta.feature_size
  73. @property
  74. def output_size(self):
  75. return self.fv_insize * self.n_comps * 2