PrecomputeExperimentalSetup.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import pickle
  2. import socket
  3. import datetime
  4. import numpy
  5. ###
  6. import sys
  7. import os
  8. sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)),os.pardir))
  9. ###
  10. import helperFunctions
  11. import datasetAcquisition
  12. ###
  13. if len(sys.argv) != 2:
  14. raise Exception('No config file given!')
  15. print ''
  16. print ' -- config -- '
  17. print ''
  18. numTasks = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numTasks', None, 'int', True)
  19. numRndInits = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numRndInits', None, 'int', True)
  20. numCls = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numCls', None, 'int', True)
  21. numInitCls = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numInitCls', None, 'int', True)
  22. numInitSamplesPerCls = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numInitSamplesPerCls', None, 'int', True)
  23. numTestSamplesPerCls = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numTestSamplesPerCls', None, 'int', True)
  24. forbiddenCls = helperFunctions.getConfig(sys.argv[1], 'data', 'forbiddenCls', [], 'intList', True)
  25. indicesFileName = helperFunctions.getConfig(sys.argv[1], 'data', 'indicesFileName', None, 'str', True)
  26. print ''
  27. print 'host:', socket.gethostname()
  28. print 'pid:', os.getpid()
  29. print 'now:', datetime.datetime.strftime(datetime.datetime.now(), '%d.%m.%Y %H:%M:%S')
  30. print 'git:', helperFunctions.getGitHash()
  31. print ''
  32. sys.stdout.flush()
  33. ###
  34. trainIdxs = list()
  35. testIdxs = list()
  36. ###
  37. print''
  38. xTrain, yTrain, xTest, yTest = datasetAcquisition.readData(sys.argv[1])
  39. ###
  40. print ''
  41. print ' -- train --'
  42. print ''
  43. uniY = numpy.unique(numpy.asarray(yTrain))
  44. ###
  45. for fCLs in forbiddenCls:
  46. uniY = numpy.delete(uniY, numpy.where(uniY == fCLs), axis=0)
  47. for fCLs in forbiddenCls:
  48. uniY = numpy.delete(uniY, numpy.where(uniY == -1), axis=0)
  49. ###
  50. for taskIdx in range(numTasks):
  51. initCls = uniY[numpy.random.permutation(len(uniY))[:numInitCls]]
  52. taskList = list()
  53. for rndInitIdx in range(numRndInits):
  54. rndInitList = list()
  55. for clsIdx in range(len(initCls)):
  56. clsSamples = numpy.ravel(numpy.where(numpy.ravel(yTrain) == numpy.ravel(initCls[clsIdx])))
  57. if rndInitIdx == 0:
  58. print 'cls', initCls[clsIdx], 'with', len(clsSamples), 'samples chosen'
  59. if len(clsSamples) < numInitSamplesPerCls:
  60. #raise Exception('To few samples!')
  61. print '>>> to few samples!'
  62. sys.stdout.flush()
  63. rndInitList.extend(list(clsSamples[numpy.random.permutation(len(clsSamples))[:numInitSamplesPerCls]]))
  64. print 'samples gathered for rndinit', rndInitIdx, 'of task', taskIdx, ' -> ', len(rndInitList)
  65. taskList.append(list(rndInitList))
  66. trainIdxs.append(list(taskList))
  67. ###
  68. print ''
  69. print ' -- test --'
  70. print ''
  71. uniY = numpy.unique(numpy.asarray(yTest))
  72. ###
  73. for fCLs in forbiddenCls:
  74. uniY = numpy.delete(uniY, numpy.where(uniY == fCLs), axis=0)
  75. for fCLs in forbiddenCls:
  76. uniY = numpy.delete(uniY, numpy.where(uniY == -1), axis=0)
  77. ###
  78. for taskIdx in range(numTasks):
  79. taskList = list()
  80. for rndInitIdx in range(numRndInits):
  81. rndInitList = list()
  82. for clsIdx in range(len(uniY)):
  83. clsSamples = numpy.ravel(numpy.where(numpy.ravel(yTest) == numpy.ravel(uniY[clsIdx])))
  84. if taskIdx == 0 and rndInitIdx == 0:
  85. print 'samples available for cls', uniY[clsIdx], ' -> ', len(clsSamples)
  86. if len(clsSamples) < numTestSamplesPerCls:
  87. #raise Exception('To few samples!')
  88. print '>>> to few samples!'
  89. sys.stdout.flush()
  90. rndInitList.extend(list(clsSamples[numpy.random.permutation(len(clsSamples))[:numTestSamplesPerCls]]))
  91. print 'samples gathered for rndinit', rndInitIdx, 'of task', taskIdx, ' -> ', len(rndInitList)
  92. taskList.append(list(rndInitList))
  93. testIdxs.append(list(taskList))
  94. ###
  95. out = open(indicesFileName, 'w')
  96. pickle.dump({'trainIdxs': trainIdxs, 'testIdxs': testIdxs}, out)
  97. out.close()
  98. print ''
  99. print 'done'