RunExperiment.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. #! /usr/bin/python
  2. import scipy.io
  3. import numpy
  4. import time
  5. import socket
  6. import datetime
  7. ###
  8. import sys
  9. import os
  10. sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)),os.pardir))
  11. ###
  12. import helperFunctions
  13. import methodSelection
  14. import datasetAcquisition
  15. ###
  16. if len(sys.argv) != 2:
  17. raise Exception('No config file given!')
  18. print ''
  19. print ' -- config I -- '
  20. print ''
  21. defaultFname = os.path.join(os.path.dirname(sys.argv[1]),os.pardir,'setup.cfg')
  22. if not os.path.isfile(defaultFname):
  23. defaultFname = sys.argv[1]
  24. setupFileName = helperFunctions.getConfig(sys.argv[1], 'experiment', 'setupFileName', defaultFname, 'str', True)
  25. numTasks = helperFunctions.getConfig(setupFileName, 'experiment', 'numTasks', 3, 'int', True)
  26. numRndInits = helperFunctions.getConfig(setupFileName, 'experiment', 'numRndInits', 3, 'int', True)
  27. numSteps = helperFunctions.getConfig(setupFileName, 'experiment', 'numSteps', 500, 'int', True)
  28. numCls = helperFunctions.getConfig(setupFileName, 'experiment', 'numCls', 80, 'int', True)
  29. forbiddenCls = helperFunctions.getConfig(setupFileName, 'experiment', 'forbiddenCls', [], 'intList', True)
  30. notificationPath = helperFunctions.getConfig(setupFileName, 'experiment', 'notificationPath', None, 'str', True)
  31. writeAttemptsNmb = helperFunctions.getConfig(setupFileName, 'experiment', 'writeAttemptsNmb', 5, 'int', True)
  32. writeAttemptsDelay = helperFunctions.getConfig(setupFileName, 'experiment', 'writeAttemptsDelay', 30, 'int', True)
  33. print ''
  34. print ' -- config II -- '
  35. print ''
  36. rejectNoise = helperFunctions.getConfig(sys.argv[1], 'experiment', 'rejectNoise', True, 'bool', True)
  37. continueExperiment = helperFunctions.getConfig(sys.argv[1], 'experiment', 'continueExperiment', False, 'bool', True)
  38. useApproximation = helperFunctions.getConfig(sys.argv[1], 'experiment', 'prepareApproximation', False, 'bool', True)
  39. alMethod = helperFunctions.getConfig(sys.argv[1], 'activeLearning', 'method', None, 'str', True)
  40. rewMethod = helperFunctions.getConfig(sys.argv[1], 'reweighting', 'method', 'None', 'str', True)
  41. startTaskIdx = helperFunctions.getConfig(sys.argv[1], 'experiment', 'startTaskIdx', 0, 'int', True)
  42. endTaskIdx = helperFunctions.getConfig(sys.argv[1], 'experiment', 'endTaskIdx', numTasks - 1, 'int', True)
  43. if (startTaskIdx != endTaskIdx) or (numTasks < 2):
  44. resultsFileName = helperFunctions.getConfig(sys.argv[1], 'experiment', 'resultsFileName', os.getcwd() + '/results.mat', 'str', True)
  45. else:
  46. resultsFileName = helperFunctions.getConfig(sys.argv[1], 'experiment', 'resultsFileName', os.getcwd() + '/results' + str(startTaskIdx) + '.mat', 'str', True)
  47. identifier = helperFunctions.getConfig(sys.argv[1], 'experiment', 'identifier', os.path.basename(os.path.dirname(resultsFileName)), 'str', True)
  48. print ''
  49. print 'host:', socket.gethostname()
  50. print 'pid:', os.getpid()
  51. print 'now:', datetime.datetime.strftime(datetime.datetime.now(), '%d.%m.%Y %H:%M:%S')
  52. print 'git:', helperFunctions.getGitHash()
  53. sys.stdout.flush()
  54. if identifier is None or resultsFileName is None or alMethod is None:
  55. raise Exception('ERROR: Config incomplete!')
  56. if not os.path.isdir(os.path.dirname(os.path.abspath(resultsFileName))) or not os.path.exists(os.path.dirname(os.path.abspath(resultsFileName))):
  57. raise Exception('ERROR: Results path does not exist!')
  58. if os.getcwd() != os.path.dirname(os.path.abspath(resultsFileName)):
  59. print ''
  60. print 'WARNING: current path != results path'
  61. print 'current:', os.getcwd()
  62. print 'rerults:', os.path.dirname(os.path.abspath(resultsFileName))
  63. ###
  64. if continueExperiment and os.path.isfile(resultsFileName):
  65. print ''
  66. print 'loading previous results ...'
  67. sys.stdout.flush()
  68. tmp = scipy.io.loadmat(resultsFileName)['results']
  69. values = tmp.item(0)
  70. names = list(tmp.dtype.names)
  71. confMats = numpy.asarray(values[names.index('confMats')], dtype=numpy.float)
  72. queriedIdxs = numpy.asarray(values[names.index('queriedIdxs')], dtype=numpy.float)
  73. #name = values[names.index('name')].item(0)
  74. identifier = values[names.index('identifier')].item(0)
  75. knownCls = numpy.asarray(values[names.index('knownCls')], dtype=numpy.float)
  76. timeNeeded = numpy.asarray(values[names.index('timeNeeded')], dtype=numpy.float)
  77. startTaskIdx = values[names.index('lastTaskIdx')].item(0)
  78. startRndInitIdx = values[names.index('lastRndInitIdx')].item(0) + 1
  79. else:
  80. queriedIdxs = numpy.zeros((numTasks,numRndInits,numSteps))
  81. timeNeeded = numpy.zeros((numTasks,numRndInits,numSteps))
  82. confMats = numpy.zeros((numTasks,numRndInits,numSteps + 1,numCls,numCls))
  83. knownCls = numpy.zeros((numTasks,numRndInits,numSteps + 1))
  84. startRndInitIdx = 0
  85. ###
  86. #numpy.random.seed(int(time.time()*1000.0))
  87. timePast = 0
  88. totalRuns = (endTaskIdx + 1)*numRndInits*numSteps - startTaskIdx*numRndInits*numSteps - startRndInitIdx*numSteps
  89. queriesPast = 0
  90. for taskIdx in range(startTaskIdx, endTaskIdx + 1):
  91. for rndInitIdx in range(startRndInitIdx, numRndInits):
  92. if notificationPath is not None:
  93. helperFunctions.writeNotification(notificationPath, 'status__' + identifier + '__' + socket.gethostname() + '__' + str(taskIdx) + '__' + str(rndInitIdx))
  94. print''
  95. print 'loading data ...'
  96. xTrain, yTrain, xPool, yPool, xTest, yTest = datasetAcquisition.readDataForInit(taskIdx, rndInitIdx, setupFileName)
  97. print''
  98. print 'training models ...'
  99. classifier = methodSelection.selectActiveLearning(alMethod, sys.argv[1])
  100. classifier.train(xTrain, yTrain)
  101. reweighter = methodSelection.selectReweighter(rewMethod, sys.argv[1])
  102. reweighter.train(xTrain, yTrain)
  103. sys.stdout.flush()
  104. pred = classifier.test(xTest)
  105. confMats[taskIdx,rndInitIdx,0,:,:] = helperFunctions.confusionMatrix(yTest, pred)
  106. knownCls[taskIdx,rndInitIdx,0] = classifier.yUni.shape[1]
  107. print ''
  108. print 'next task:', taskIdx, ', next rndInit:', rndInitIdx
  109. print 'xTrain: {}, yTrain: {} [#cls: {}], xPool: {}, yPool: {} [#cls: {}, #noise: {}], xTest: {}, yTest: {} [#cls: {}]'.format(xTrain.shape, yTrain.shape, len(numpy.unique(numpy.asarray(yTrain))), xPool.shape, yPool.shape, len(numpy.unique(numpy.asarray(yPool))), numpy.sum(yPool==-1), xTest.shape, yTest.shape, len(numpy.unique(numpy.asarray(yTest))))
  110. print 'initial acc:', helperFunctions.getAvgAcc(confMats[taskIdx,rndInitIdx,0,:,:]), ', initial knownCls:', int(knownCls[taskIdx,rndInitIdx,0])
  111. sys.stdout.flush()
  112. orgIdxs = numpy.asmatrix(range(1,yPool.shape[0] + 1))
  113. if useApproximation:
  114. print 'prepare approximation ...'
  115. print
  116. sys.stdout.flush()
  117. classifier.prepareApprox(xPool)
  118. for step in range(numSteps):
  119. t0 = time.time()
  120. alScores1 = classifier.getAlScores(xPool)
  121. alScores2 = reweighter.reweight(alScores1, xPool)
  122. chosenIdx = numpy.argmax(alScores2, axis=0).item(0)
  123. newX = xPool[chosenIdx,:]
  124. newY = yPool[chosenIdx,:]
  125. reweighter.update(newX, newY)
  126. if not(rejectNoise and newY == -1):
  127. classifier.update(newX, newY)
  128. if newY == -1:
  129. print '-- updated with noise'
  130. else:
  131. print '-- noise drawn and rejected'
  132. queriedIdxs[taskIdx,rndInitIdx,step] = orgIdxs[0,chosenIdx]
  133. pred = classifier.test(xTest, True)
  134. confMats[taskIdx,rndInitIdx,step + 1,:,:] = helperFunctions.confusionMatrix(yTest, pred)
  135. knownCls[taskIdx,rndInitIdx,step + 1] = classifier.yUni.shape[1] - (classifier.yUni == -1).any()
  136. xPool = numpy.delete(xPool, (chosenIdx), axis=0)
  137. yPool = numpy.delete(yPool, (chosenIdx), axis=0)
  138. orgIdxs = numpy.delete(orgIdxs, (chosenIdx), axis=1)
  139. if useApproximation:
  140. classifier.clearApprox(chosenIdx)
  141. t1 = time.time()
  142. timeNeeded[taskIdx,rndInitIdx,step] = (t1 - t0)
  143. queriesPast = queriesPast + 1
  144. timePast = timePast + float(t1 - t0)
  145. timePerPass = timePast / float(queriesPast)
  146. timeOva = timePerPass * totalRuns
  147. estTimeLeft = timeOva - timePast
  148. print queriesPast, '/', totalRuns, '- time past:', '%.3f'%(timePast/3600.0), 'h, avg. time per pass:', '%.3f'%timePerPass, 's, est. time left:', '%.3f'%(estTimeLeft/3600.0), 'h, est. time over all:', '%.3f'%(timeOva/3600.0), 'h'
  149. print 'chosenIdx:', chosenIdx, ', task:', taskIdx, ', rndInit:', rndInitIdx, ', step:', step, ', acc:', '%.5f'%helperFunctions.getAvgAcc(confMats[taskIdx,rndInitIdx,step + 1,:,:]), ', knownCls:', int(knownCls[taskIdx,rndInitIdx,step + 1])
  150. sys.stdout.flush()
  151. results = dict(confMats=confMats, queriedIdxs=queriedIdxs, name=identifier, identifier=identifier, knownCls=knownCls, timeNeeded=timeNeeded, lastTaskIdx=taskIdx, lastRndInitIdx=rndInitIdx)
  152. writeAttempt = 0
  153. while True:
  154. try:
  155. scipy.io.savemat(resultsFileName, dict(results=results))
  156. break
  157. except:
  158. writeAttempt = writeAttempt + 1
  159. if writeAttempt >= writeAttemptsNmb:
  160. raise Exception('ERROR: Writing file {} failed {} times!'.format(resultsFileName, writeAttempt))
  161. print ''
  162. print 'WARNING: Writing file {} failed ({} / {}, retry after {} seconds)!'.format(resultsFileName, writeAttempt, writeAttemptsNmb, writeAttemptsDelay)
  163. sys.stdout.flush()
  164. time.sleep(writeAttemptsDelay)
  165. startRndInitIdx = 0
  166. if notificationPath is not None:
  167. helperFunctions.writeNotification(notificationPath, 'status__' + identifier + '__' + socket.gethostname() + '__done')
  168. print 'done'
  169. print 'now', datetime.datetime.strftime(datetime.datetime.now(), '%d.%m.%Y %H:%M:%S')