reweightWlinGP.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #! /usr/bin/python
  2. import numpy
  3. import scipy.special
  4. import sys
  5. import os
  6. sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)),os.pardir))
  7. import helperFunctions
  8. class Reweighter:
  9. def __init__(self,
  10. sigmaN = 0.00178,
  11. configFile=None):
  12. self.sigmaN = helperFunctions.getConfig(configFile, 'reweighting', 'sigmaN', sigmaN, 'float', True)
  13. self.w = [] # .shape = (feat dim, number of unique classes)
  14. self.X = []
  15. self.y = []
  16. def checkModel(self):
  17. if not numpy.all(numpy.isfinite(self.w)):
  18. raise Exception('not numpy.all(numpy.isfinite(self.w))')
  19. if not numpy.all(numpy.isfinite(self.X)):
  20. raise Exception('not numpy.all(numpy.isfinite(self.X))')
  21. if not numpy.all(numpy.isfinite(self.y)):
  22. raise Exception('not numpy.all(numpy.isfinite(self.y))')
  23. # X.shape = (number of samples, feat dim), y.shape = (number of samples, 1)
  24. def train(self, X, y):
  25. # save stuff
  26. self.X = X
  27. self.y = numpy.asmatrix(y == -1, dtype=numpy.int)*2 - 1
  28. # sample reweighting
  29. rewDiagMat = helperFunctions.getReweightDiagMat(self.y, numpy.asmatrix(numpy.unique(numpy.asarray(self.y))))
  30. rewX = numpy.dot(rewDiagMat,self.X)
  31. rewY = numpy.dot(rewDiagMat,self.y)
  32. # calculate w
  33. self.w = helperFunctions.solveW(rewX, rewY, self.sigmaN)
  34. self.checkModel()
  35. # x.shape = (1, feat dim), y.shape = (1, 1)
  36. def update(self, x, y):
  37. # relabel
  38. relY = numpy.asmatrix(y == -1, dtype=numpy.int)*2 - 1
  39. # update samples
  40. self.X = numpy.append(self.X, x, axis=0)
  41. self.y = numpy.append(self.y, relY, axis=0)
  42. # sample reweighting
  43. rewDiagMat = helperFunctions.getReweightDiagMat(self.y, numpy.asmatrix(numpy.unique(numpy.asarray(self.y))))
  44. rewX = numpy.dot(rewDiagMat,self.X)
  45. rewY = numpy.dot(rewDiagMat,self.y)
  46. # calculate w
  47. self.w = helperFunctions.solveW(rewX, rewY, self.sigmaN, self.w)
  48. self.checkModel()
  49. # X.shape = (number of samples, feat dim)
  50. def infer(self, x):
  51. pred = numpy.asmatrix(x*self.w)
  52. if not numpy.all(numpy.isfinite(pred)):
  53. raise Exception('not numpy.all(numpy.isfinite(pred))')
  54. return pred
  55. # X.shape = (number of samples, feat dim)
  56. def reweight(self, alScores, x):
  57. clsScores = self.infer(x)
  58. # sample reweighting
  59. rewDiagMat = helperFunctions.getReweightDiagMat(self.y, numpy.asmatrix(numpy.unique(numpy.asarray(self.y))))
  60. rewX = numpy.dot(rewDiagMat,self.X)
  61. var = numpy.sum(numpy.multiply(x,(self.sigmaN*numpy.linalg.solve(numpy.add(numpy.dot(rewX.T,rewX), numpy.identity(x.shape[1])*self.sigmaN), x.T)).T),axis=1)
  62. probs = 0.5 - 0.5*scipy.special.erf(clsScores*(-1)/numpy.sqrt(2*var))
  63. newAlScores = numpy.multiply(alScores,(1 - probs))
  64. if not numpy.all(numpy.isfinite(newAlScores)):
  65. raise Exception('not numpy.all(numpy.isfinite(newAlScores))')
  66. return newAlScores