123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
import datetime
import os
import pickle
import socket
import sys

import numpy

# Put the directory above this script's own directory on the import path so
# the project-local modules imported below can be found.
_scriptDir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(_scriptDir, os.pardir))
del _scriptDir

import helperFunctions
import datasetAcquisition
# Label value that is always excluded from the candidate classes, on top of
# the configured ``forbiddenCls``.
# NOTE(review): presumably marks unlabeled samples in the data returned by
# datasetAcquisition.readData -- confirm.
_IGNORED_LABEL = -1


def _log(*parts):
    """Print *parts* on one line, separated by single spaces.

    Reproduces the output of the Python 2 ``print a, b, c`` statements this
    script originally used, while being valid under Python 2 and Python 3.
    """
    print(' '.join(str(part) for part in parts))


def _candidateClasses(labels, forbiddenCls):
    """Return the sorted distinct values of *labels*, minus *forbiddenCls* and -1.

    The original script removed -1 inside a loop over ``forbiddenCls``, so -1
    was only dropped when ``forbiddenCls`` was non-empty. It is now dropped
    unconditionally.
    """
    excluded = list(forbiddenCls) + [_IGNORED_LABEL]
    # setdiff1d returns the sorted unique values of the first argument that do
    # not occur in the second, i.e. numpy.unique(labels) minus the exclusions.
    return numpy.setdiff1d(numpy.asarray(labels), excluded)


def _positionsOf(flatLabels, cls):
    """Return the positions in the 1-D array *flatLabels* whose value equals *cls*."""
    return numpy.flatnonzero(flatLabels == cls)


def _drawSamples(samples, count):
    """Return up to *count* entries of *samples* in random order, as a list.

    If fewer than *count* are available, all of them are returned and a
    warning is printed. This is a deliberate warning, not an error.
    """
    if len(samples) < count:
        _log('>>> to few samples!')
        sys.stdout.flush()
    return list(samples[numpy.random.permutation(len(samples))[:count]])


def sampleTrainIndices(yTrain, numTasks, numRndInits, numInitCls,
                       numInitSamplesPerCls, forbiddenCls=()):
    """Draw random initial-training sample indices.

    For every task, up to *numInitCls* classes are picked at random from the
    candidate classes of *yTrain*. For each of *numRndInits* random
    initialisations, up to *numInitSamplesPerCls* sample positions are then
    drawn per picked class.

    :param yTrain: training labels (flattened before use).
    :param forbiddenCls: class labels never picked. -1 is always excluded.
    :returns: nested list ``[task][rndInit] -> list of positions in yTrain``.
    """
    classes = _candidateClasses(yTrain, forbiddenCls)
    flatLabels = numpy.ravel(numpy.asarray(yTrain))
    # Positions per class are computed lazily, once per class, instead of
    # once per (task, rndInit, class) as before.
    positionsByClass = {}
    trainIdxs = []
    for taskIdx in range(numTasks):
        chosen = [int(pos) for pos in numpy.random.permutation(len(classes))[:numInitCls]]
        taskList = []
        for rndInitIdx in range(numRndInits):
            rndInitList = []
            for pos in chosen:
                if pos not in positionsByClass:
                    positionsByClass[pos] = _positionsOf(flatLabels, classes[pos])
                samples = positionsByClass[pos]
                if rndInitIdx == 0:
                    _log('cls', classes[pos], 'with', len(samples), 'samples chosen')
                rndInitList.extend(_drawSamples(samples, numInitSamplesPerCls))
            _log('samples gathered for rndinit', rndInitIdx, 'of task', taskIdx, ' -> ', len(rndInitList))
            taskList.append(rndInitList)
        trainIdxs.append(taskList)
    return trainIdxs


def sampleTestIndices(yTest, numTasks, numRndInits, numTestSamplesPerCls,
                      forbiddenCls=()):
    """Draw random test sample indices covering every candidate class.

    For every task and each of *numRndInits* random initialisations, up to
    *numTestSamplesPerCls* sample positions are drawn from every candidate
    class of *yTest*.

    :param yTest: test labels (flattened before use).
    :param forbiddenCls: class labels never sampled. -1 is always excluded.
    :returns: nested list ``[task][rndInit] -> list of positions in yTest``.
    """
    classes = _candidateClasses(yTest, forbiddenCls)
    flatLabels = numpy.ravel(numpy.asarray(yTest))
    # Every class is used in every draw, so look all of them up once.
    positionsPerClass = [_positionsOf(flatLabels, cls) for cls in classes]
    testIdxs = []
    for taskIdx in range(numTasks):
        taskList = []
        for rndInitIdx in range(numRndInits):
            rndInitList = []
            for cls, samples in zip(classes, positionsPerClass):
                if taskIdx == 0 and rndInitIdx == 0:
                    _log('samples available for cls', cls, ' -> ', len(samples))
                rndInitList.extend(_drawSamples(samples, numTestSamplesPerCls))
            _log('samples gathered for rndinit', rndInitIdx, 'of task', taskIdx, ' -> ', len(rndInitList))
            taskList.append(rndInitList)
        testIdxs.append(taskList)
    return testIdxs


def main(argv=None):
    """Read the config named on the command line, sample indices, pickle them.

    The output file (config ``data/indicesFileName``) receives a pickled dict
    ``{'trainIdxs': ..., 'testIdxs': ...}`` as built by sampleTrainIndices and
    sampleTestIndices.

    :param argv: argument vector; defaults to ``sys.argv``. ``argv[1]`` is the
        config file, passed to helperFunctions.getConfig and
        datasetAcquisition.readData.
    :raises Exception: if *argv* does not hold exactly one argument after the
        program name.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 2:
        raise Exception('No config file given!')
    configFile = argv[1]

    def config(section, key, default, valueType):
        # The trailing True is passed through unchanged from the original
        # script; its meaning is defined by helperFunctions.getConfig.
        return helperFunctions.getConfig(configFile, section, key, default, valueType, True)

    _log('')
    _log(' -- config -- ')
    _log('')
    numTasks = config('experiment', 'numTasks', None, 'int')
    numRndInits = config('experiment', 'numRndInits', None, 'int')
    # numCls is not used below. It is still read so the getConfig call (and
    # whatever checking/logging it performs) stays the same as before.
    config('experiment', 'numCls', None, 'int')
    numInitCls = config('experiment', 'numInitCls', None, 'int')
    numInitSamplesPerCls = config('experiment', 'numInitSamplesPerCls', None, 'int')
    numTestSamplesPerCls = config('experiment', 'numTestSamplesPerCls', None, 'int')
    forbiddenCls = config('data', 'forbiddenCls', [], 'intList')
    indicesFileName = config('data', 'indicesFileName', None, 'str')
    _log('')
    _log('host:', socket.gethostname())
    _log('pid:', os.getpid())
    _log('now:', datetime.datetime.now().strftime('%d.%m.%Y %H:%M:%S'))
    _log('git:', helperFunctions.getGitHash())
    _log('')
    sys.stdout.flush()

    _log('')
    _, yTrain, _, yTest = datasetAcquisition.readData(configFile)

    _log('')
    _log(' -- train --')
    _log('')
    trainIdxs = sampleTrainIndices(yTrain, numTasks, numRndInits, numInitCls,
                                   numInitSamplesPerCls, forbiddenCls)

    _log('')
    _log(' -- test --')
    _log('')
    testIdxs = sampleTestIndices(yTest, numTasks, numRndInits,
                                 numTestSamplesPerCls, forbiddenCls)

    # Binary mode: required for pickle under Python 3 and avoids newline
    # translation of the pickle stream on Windows. `with` closes the file even
    # if pickling fails.
    with open(indicesFileName, 'wb') as out:
        pickle.dump({'trainIdxs': trainIdxs, 'testIdxs': testIdxs}, out)
    _log('')
    _log('done')


if __name__ == '__main__':
    main()
|