helperFunctions.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import numpy
  2. import math
  3. import ConfigParser
  4. import os
  5. import sys
  6. import subprocess
  7. import scipy.sparse.linalg
  8. import time
  def confusionMatrix(expected, predicted):
      """Build a confusion matrix with expected labels on rows, predicted on columns.

      The class list is ``numpy.unique`` of `expected`, so row/column k belongs
      to the k-th smallest expected label.

      Parameters
      ----------
      expected, predicted : array-like
          Label sequences, walked in lockstep with ``zip``.

      Returns
      -------
      numpy.matrix
          Float matrix of shape (K, K), K = number of distinct expected labels.

      Raises
      ------
      Exception
          If the matrix total differs from ``max(predicted.shape)``. This
          happens, e.g., when a predicted label never occurs in `expected`
          (such pairs match no column and are not counted). Diagnostics are
          printed before raising.
      """
      classes = numpy.asmatrix(numpy.unique(numpy.asarray(expected)))
      nClasses = classes.shape[1]
      counts = numpy.asmatrix(numpy.zeros((nClasses, nClasses)))
      for trueLabel, predLabel in zip(expected, predicted):
          # [1] selects the column index inside the 1 x K class matrix; an
          # unknown label gives an empty index and nothing is incremented.
          rowIdx = numpy.where(classes == trueLabel)[1]
          colIdx = numpy.where(classes == predLabel)[1]
          counts[rowIdx, colIdx] += 1
      nPredicted = max(predicted.shape)
      if numpy.sum(counts) != nPredicted:
          # One pre-joined string per print(...) -> same output on Python 2 and 3.
          print(' '.join(str(part) for part in
                         (numpy.sum(counts), '!=', nPredicted, '(', predicted.shape, ')')))
          print(' '.join(['cls expected: ', str(numpy.unique(numpy.asarray(expected)))]))
          print(' '.join(['cls predicted:', str(numpy.unique(numpy.asarray(predicted)))]))
          raise Exception('# predicted cls > # expected cls')
      return counts
  def getAvgAcc(confMat):
      """Return the mean per-class accuracy (mean recall) of a confusion matrix.

      Row k is taken to hold the samples whose expected class is k (the layout
      built by ``confusionMatrix`` in this module); the result is the mean over
      rows of ``confMat[k, k] / sum(confMat[k, :])``.

      Parameters
      ----------
      confMat : array-like, shape (K, K)
          numpy.matrix, ndarray or nested lists; integer or float counts.

      Returns
      -------
      numpy.float64
          Average accuracy. A row summing to zero contributes nan.
      """
      # Float ndarray: integer counts are never floor-divided (Python 2 '/'
      # on integer arrays), and matrix/ndarray/list inputs all behave alike.
      counts = numpy.asarray(confMat, dtype=float)
      return numpy.mean(numpy.diagonal(counts) / numpy.sum(counts, axis=1))
  def getConfig(pathtoConfig, section, option, default=None, dtype='str', verbose=False):
      """Read one option from an INI-style config file, falling back to `default`.

      Parameters
      ----------
      pathtoConfig : str or None
          Path of the config file. When None or not an existing file, `default`
          is returned.
      section, option : str
          Where to look inside the file. A missing section or option also
          yields `default`.
      default : optional
          Value returned unchanged whenever the option is not read.
      dtype : str
          'str', 'int', 'float', 'bool', or a list variant 'strList',
          'intList', 'floatList', 'boolList'. List variants split the raw
          value on ','. 'boolList' entries accept the same spellings as
          'bool' (1/yes/true/on, 0/no/false/off; case-insensitive, surrounding
          whitespace ignored).
      verbose : bool
          Print whether the default was used, plus the value and its type.

      Returns
      -------
      The parsed value, or `default`.

      Raises
      ------
      Exception
          If the option exists but `dtype` is not one of the names above.
      ValueError
          If the raw value cannot be converted to the requested type.
      """
      # Spellings accepted by ConfigParser.getboolean. Needed for 'boolList'
      # because bool() of any non-empty string (even 'False') is True.
      boolStates = {'1': True, 'yes': True, 'true': True, 'on': True,
                    '0': False, 'no': False, 'false': False, 'off': False}

      def toBool(entry):
          key = entry.strip().lower()
          if key not in boolStates:
              raise ValueError('Not a boolean: %s' % entry)
          return boolStates[key]

      value = default
      defaultUsed = True
      if pathtoConfig is not None and os.path.isfile(pathtoConfig):
          config = ConfigParser.ConfigParser()
          # 'with' closes the file even when parsing raises.
          with open(pathtoConfig) as configFile:
              # read_file supersedes readfp, which was removed in Python 3.12.
              readFile = getattr(config, 'read_file', None) or config.readfp
              readFile(configFile)
          if config.has_section(section) and config.has_option(section, option):
              if dtype == 'str':
                  value = config.get(section, option)
              elif dtype == 'int':
                  value = config.getint(section, option)
              elif dtype == 'float':
                  value = config.getfloat(section, option)
              elif dtype == 'bool':
                  value = config.getboolean(section, option)
              elif dtype == 'strList':
                  value = config.get(section, option).split(',')
              elif dtype == 'intList':
                  value = [int(entry) for entry in config.get(section, option).split(',')]
              elif dtype == 'floatList':
                  value = [float(entry) for entry in config.get(section, option).split(',')]
              elif dtype == 'boolList':
                  value = [toBool(entry) for entry in config.get(section, option).split(',')]
              else:
                  raise Exception('Unknown dtype!')
              defaultUsed = False
      if verbose:
          aux = ''
          # value may be a non-list default (e.g. None): only index real sequences.
          if 'List' in dtype and isinstance(value, (list, tuple)) and len(value) > 0:
              aux = '| entryDtype:' + str(type(value[0]))
          # One pre-joined string -> print(...) behaves the same on Python 2 and 3.
          print(' '.join(str(part) for part in (
              'default:', defaultUsed, '| section:', section, '| option:', option,
              '| value:', value, '| dtype:', type(value), aux)))
      return value
  def getGitHash(gitPath=os.path.dirname(os.path.abspath(__file__))):
      """Return the abbreviated hash of the commit checked out at `gitPath`.

      Runs ``git rev-parse --short HEAD`` with `gitPath` as the child's working
      directory. The default is the directory containing this module, evaluated
      once when the module is imported.

      Returns
      -------
      The stripped command output (``str`` on Python 2, ``bytes`` on Python 3).

      Raises
      ------
      subprocess.CalledProcessError
          If git exits non-zero (e.g. `gitPath` is not inside a repository);
          git's stderr is merged into the exception's ``output``.
      OSError
          If git cannot be executed or `gitPath` is not an accessible directory.
      """
      # cwd= only affects the child process. Unlike os.chdir(), it cannot leave
      # this process in the wrong directory when git fails.
      return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'],
                                     cwd=gitPath, stderr=subprocess.STDOUT).strip()
  def getYfromYbin(yBin, yUni):
      """Convert a one-vs-all label matrix back to a column of class labels.

      Parameters
      ----------
      yBin : array-like, shape (n, K)
          Row i marks its class with the value 1 in exactly one column; the
          other entries may be anything else (e.g. 0 or -1).
      yUni : array-like, shape (1, K) or (K,)
          Class label belonging to each column of `yBin`.

      Returns
      -------
      numpy.matrix, shape (n, 1)
          Integer labels, ``int(...)`` of the matching `yUni` entry.

      Raises
      ------
      ValueError
          If a row of `yBin` does not contain exactly one entry equal to 1.
      """
      isHot = numpy.asarray(yBin) == 1
      if isHot.shape[0] == 0:
          return numpy.asmatrix(numpy.zeros((0, 1), dtype=int))
      hotPerRow = numpy.sum(isHot, axis=1)
      if numpy.any(hotPerRow != 1):
          badRow = int(numpy.flatnonzero(hotPerRow != 1)[0])
          raise ValueError('row {} of yBin must contain exactly one entry equal to 1'.format(badRow))
      classes = numpy.ravel(numpy.asarray(yUni))
      # argmax of a boolean row is the column of its single True entry.
      # Builtin int replaces numpy.int (a plain alias, removed in NumPy 1.24).
      labels = classes[numpy.argmax(isHot, axis=1)].astype(int)
      return numpy.asmatrix(labels.reshape(-1, 1))
  def solveW(x, y, sigmaN, initW=None, maxIter=None):
      """Solve the regularised normal equations column by column with CG.

      For every column c of `y`, conjugate gradients solves
          (x^T x + sigmaN * I) w_c = x^T y_c
      using only products with `x` and `x^T` (via
      ``scipy.sparse.linalg.aslinearoperator``), so x^T x is never formed.

      Parameters
      ----------
      x : array-like or sparse matrix, shape (n, d)
      y : numpy.matrix or 2-D array, shape (n, C)
      sigmaN : float
          Weight added to the diagonal of x^T x.
      initW : array-like, shape (d, C), optional
          CG starting point, one column per target.
      maxIter : int, optional
          Iteration cap passed to CG. When given, stopping before the tolerance
          is reached is expected and not warned about.

      Returns
      -------
      numpy.matrix, shape (d, C)
      """
      linOpX = scipy.sparse.linalg.aslinearoperator(x)
      nFeat = x.shape[1]

      def matvecFunc(curW):
          return linOpX.rmatvec(linOpX.matvec(curW)) + sigmaN * curW

      # The system operator is the same for every column: build it once.
      linOpW = scipy.sparse.linalg.LinearOperator((nFeat, nFeat), matvec=matvecFunc, dtype=x.dtype)
      w = numpy.asmatrix(numpy.empty((nFeat, y.shape[1])))
      for clsIdx in range(y.shape[1]):
          initWbin = None
          if initW is not None:
              initWbin = numpy.ravel(numpy.asarray(initW[:, clsIdx]))
          solvedWbin, info = scipy.sparse.linalg.cg(linOpW, linOpX.rmatvec(y[:, clsIdx]),
                                                    x0=initWbin, maxiter=maxIter)
          if info > 0 and maxIter is None:
              print('')
              print('WARNING: cg not converged!')
              print('')
          elif info < 0:
              # Negative info means illegal input or breakdown. Report it even
              # when an iteration cap was requested.
              print('')
              print('WARNING: cg breakdown or illegal input (info={})!'.format(info))
              print('')
          w[:, clsIdx] = numpy.asmatrix(solvedWbin).T
      return w
  def getClsWeights(y, yUni):
      """Return balanced per-class weights n / (K * n_k).

      n is ``y.shape[0]``, K the number of columns of `yUni` (one class per
      column) and n_k the number of entries of `y` equal to ``yUni[0, k]``.
      Rare classes therefore get large weights. A class that never occurs in
      `y` raises ZeroDivisionError.

      Returns
      -------
      numpy.ndarray of float, shape (K,)
      """
      nSamples = y.shape[0]
      nClasses = float(yUni.shape[1])
      occurrences = [numpy.argwhere(y == yUni[0, k]).shape[0] for k in range(yUni.shape[1])]
      return numpy.array([nSamples / (nClasses * cnt) for cnt in occurrences], dtype=float)
  def writeNotification(notificationPath, statusStr):
      """Create (or touch) the empty marker file `statusStr` in `notificationPath`.

      Best effort: a failure is reported on stdout instead of raised, so an
      unavailable notification directory never aborts the caller. The file is
      opened in append mode, so existing contents are preserved.
      """
      try:
          with open(os.path.join(notificationPath, statusStr), 'a'):
              pass
      except Exception as err:
          # Exception, not a bare except: KeyboardInterrupt and SystemExit must
          # still propagate; everything else stays best-effort.
          print('')
          print('ERROR: writing notification to {} failed!'.format(notificationPath))
          print('  reason: {}'.format(err))
          sys.stdout.flush()
  def getReweightDiagMat(y, yUni=None, clsWeights=None):
      """Return a diagonal matrix of square-rooted per-sample class weights.

      Diagonal entry i is ``sqrt(clsWeights[k])``, where k is the position of
      the i-th label of `y` inside `yUni`.

      Parameters
      ----------
      y : array-like
          Labels; flattened, so an (n, 1) column or a 1-D vector both work.
      yUni : array-like, optional
          Class labels, one per class, in any order. Defaults to the sorted
          unique values of `y` as a 1 x K matrix.
      clsWeights : array-like, shape (K,), optional
          One weight per entry of `yUni`. Defaults to ``getClsWeights(y, yUni)``.

      Returns
      -------
      numpy.matrix
          Dense square diagonal matrix with one row/column per label in `y`.

      Raises
      ------
      ValueError
          If `y` contains a label that is not in `yUni`.
      """
      if yUni is None:
          yUni = numpy.asmatrix(numpy.unique(numpy.asarray(y)))
      if clsWeights is None:
          clsWeights = getClsWeights(y, yUni)
      uniVals = numpy.ravel(numpy.asarray(yUni))
      labels = numpy.ravel(numpy.asarray(y))
      # searchsorted needs sorted input; 'sorter' lets yUni keep its own order
      # (which is also the order of clsWeights).
      order = numpy.argsort(uniVals)
      pos = numpy.searchsorted(uniVals, labels, sorter=order)
      # searchsorted returns an insertion point; confirm each label is really present.
      found = pos < uniVals.size
      found[found] = uniVals[order[pos[found]]] == labels[found]
      if not numpy.all(found):
          raise ValueError('y contains labels that are not in yUni')
      sampleWeights = numpy.ravel(numpy.asarray(clsWeights))[order[pos]]
      return numpy.asmatrix(numpy.diag(numpy.sqrt(sampleWeights)))
  def showProgressBarTerminal(current, total, pre):
      """Overwrite the current terminal line with '<pre> <percent> %'.

      The leading carriage return moves the cursor to the start of the line,
      so repeated calls update the same line; no newline is written. A `total`
      of 0 (nothing to do) is shown as 100 % instead of raising
      ZeroDivisionError.
      """
      if total == 0:
          fraction = 1.0
      else:
          fraction = float(current) / float(total)
      sys.stdout.write('\r%s %0.2f %%' % (pre, fraction * 100.0))
      sys.stdout.flush()