PrecomputeExperimentalSetup.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import pickle
  2. import socket
  3. import datetime
  4. import numpy
  5. ###
  6. import sys
  7. import os
  8. sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)),os.pardir))
  9. ###
  10. import helperFunctions
  11. import datasetAcquisition
  12. ###
  13. if len(sys.argv) != 2:
  14. raise Exception('No config file given!')
  15. print ''
  16. print ' -- config -- '
  17. print ''
  18. defaultFname = os.path.join(os.path.dirname(sys.argv[1]),os.pardir,'setup.cfg')
  19. if not os.path.isfile(defaultFname):
  20. defaultFname = sys.argv[1]
  21. expSetup = helperFunctions.getConfig(sys.argv[1], 'experiment', 'extExpSetup', defaultFname, 'str', True)
  22. numRndInits = helperFunctions.getConfig(expSetup, 'experiment', 'numRndInits', None, 'int', True)
  23. numInitSamples = helperFunctions.getConfig(expSetup, 'experiment', 'numInitSamples', None, 'int', True)
  24. numTestSamples = helperFunctions.getConfig(expSetup, 'experiment', 'numTestSamples', None, 'int', True)
  25. indicesFileName = helperFunctions.getConfig(expSetup, 'experiment', 'indicesFileName', None, 'str', True)
  26. print ''
  27. print 'host:', socket.gethostname()
  28. print 'pid:', os.getpid()
  29. print 'now:', datetime.datetime.strftime(datetime.datetime.now(), '%d.%m.%Y %H:%M:%S')
  30. print 'git:', helperFunctions.getGitHash()
  31. print ''
  32. sys.stdout.flush()
  33. ###
  34. #trainIdxs = numpy.empty((numTasks,numRndInits,numInitCls*numInitSamplesPerCls), dtype=numpy.int)
  35. #testIdxs = numpy.empty((numTasks,numRndInits,numCls*numTestSamplesPerCls), dtype=numpy.int)
  36. trainIdxs = list()
  37. testIdxs = list()
  38. ###
  39. x,y = datasetAcquisition.readData(expSetup)
  40. for rndIdx in range(numRndInits):
  41. rndIdxs = numpy.random.permutation(y.shape[0])
  42. trainIdxs.append(rndIdxs[:numInitSamples])
  43. testIdxs.append(rndIdxs[numInitSamples+1:numInitSamples+numTestSamples+1])
  44. ###
  45. out = open(indicesFileName, 'w')
  46. pickle.dump({'trainIdxs': trainIdxs, 'testIdxs': testIdxs}, out)
  47. out.close()
  48. print ''
  49. print 'done'