# Draws random train/test sample index sets for an experiment and stores them
# in a pickle file.
#
# Usage: <this script> <configFile>
#
# The pickled object is a dict with two entries:
#   'trainIdxs'[taskIdx][rndInitIdx] -> list of indices into the flattened yTrain
#   'testIdxs'[taskIdx][rndInitIdx]  -> list of indices into the flattened yTest
#
# print_function keeps the printed output identical under Python 2 while making
# the print syntax valid on Python 3 as well.
from __future__ import print_function

import datetime
import os
import pickle
import socket
import sys

import numpy

###
# make the parent directory importable before loading the project modules
sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), os.pardir))
###
import helperFunctions
import datasetAcquisition


###
def _allowedClasses(labels, forbidden):
    """Return the sorted unique labels with the forbidden classes removed."""
    classes = numpy.unique(numpy.asarray(labels))
    for cls in forbidden:
        classes = numpy.delete(classes, numpy.flatnonzero(classes == cls), axis=0)
    # NOTE(review): the label -1 is only stripped when at least one forbidden
    # class is configured. This mirrors the original behaviour (it removed -1
    # once per forbidden class) -- confirm whether -1 should always be dropped.
    if len(forbidden) > 0:
        classes = numpy.delete(classes, numpy.flatnonzero(classes == -1), axis=0)
    return classes


def _gatherSamples(labels, classes, samplesPerCls, report=None):
    """Randomly pick up to samplesPerCls sample indices for every class.

    labels        -- label container; it is flattened before matching
    classes       -- iterable of class labels to draw samples for
    samplesPerCls -- maximum number of indices drawn per class
    report        -- optional callable(cls, numAvailable), invoked once per class

    Returns one flat list with the chosen indices of all classes, in class
    order. A class with fewer than samplesPerCls samples only triggers a
    warning and contributes all samples it has.
    """
    flatLabels = numpy.ravel(labels)
    gathered = []
    for cls in classes:
        clsSamples = numpy.flatnonzero(flatLabels == numpy.ravel(cls))
        if report is not None:
            report(cls, len(clsSamples))
        if len(clsSamples) < samplesPerCls:
            # deliberately only a warning: raising here was disabled in the
            # original script
            print('>>> to few samples!')
            sys.stdout.flush()
        # exactly one permutation draw per class, so the sequence of random
        # numbers consumed stays the same as in the original loops
        order = numpy.random.permutation(len(clsSamples))
        gathered.extend(list(clsSamples[order[:samplesPerCls]]))
    return gathered


def _reportChosen(cls, numAvailable):
    """Print how many samples exist for a class chosen for training."""
    print('cls', cls, 'with', numAvailable, 'samples chosen')


def _reportAvailable(cls, numAvailable):
    """Print how many test samples exist for a class."""
    print('samples available for cls', cls, ' -> ', numAvailable)


###
if len(sys.argv) != 2:
    raise Exception('No config file given!')
configFile = sys.argv[1]

print('')
print(' -- config -- ')
print('')
# The meaning of the trailing getConfig arguments (default, type name, flag) is
# defined in helperFunctions and not visible here; they are passed unchanged.
numTasks = helperFunctions.getConfig(configFile, 'experiment', 'numTasks', None, 'int', True)
numRndInits = helperFunctions.getConfig(configFile, 'experiment', 'numRndInits', None, 'int', True)
numCls = helperFunctions.getConfig(configFile, 'experiment', 'numCls', None, 'int', True)
numInitCls = helperFunctions.getConfig(configFile, 'experiment', 'numInitCls', None, 'int', True)
numInitSamplesPerCls = helperFunctions.getConfig(configFile, 'experiment', 'numInitSamplesPerCls', None, 'int', True)
numTestSamplesPerCls = helperFunctions.getConfig(configFile, 'experiment', 'numTestSamplesPerCls', None, 'int', True)
forbiddenCls = helperFunctions.getConfig(configFile, 'data', 'forbiddenCls', [], 'intList', True)
indicesFileName = helperFunctions.getConfig(configFile, 'data', 'indicesFileName', None, 'str', True)

print('')
print('host:', socket.gethostname())
print('pid:', os.getpid())
print('now:', datetime.datetime.now().strftime('%d.%m.%Y %H:%M:%S'))
print('git:', helperFunctions.getGitHash())
print('')
sys.stdout.flush()

###
trainIdxs = []
testIdxs = []

###
print('')
xTrain, yTrain, xTest, yTest = datasetAcquisition.readData(configFile)

###
print('')
print(' -- train --')
print('')
uniY = _allowedClasses(yTrain, forbiddenCls)

###
for taskIdx in range(numTasks):
    # every task gets its own random subset of numInitCls classes
    initCls = uniY[numpy.random.permutation(len(uniY))[:numInitCls]]
    taskList = []
    for rndInitIdx in range(numRndInits):
        # the per-class summary is printed only for the first random init
        report = _reportChosen if rndInitIdx == 0 else None
        rndInitList = _gatherSamples(yTrain, initCls, numInitSamplesPerCls, report)
        print('samples gathered for rndinit', rndInitIdx, 'of task', taskIdx, ' -> ', len(rndInitList))
        taskList.append(rndInitList)
    trainIdxs.append(taskList)

###
print('')
print(' -- test --')
print('')
uniY = _allowedClasses(yTest, forbiddenCls)

###
for taskIdx in range(numTasks):
    taskList = []
    for rndInitIdx in range(numRndInits):
        # the per-class summary is printed only once for the whole test part
        report = _reportAvailable if (taskIdx == 0 and rndInitIdx == 0) else None
        rndInitList = _gatherSamples(yTest, uniY, numTestSamplesPerCls, report)
        print('samples gathered for rndinit', rndInitIdx, 'of task', taskIdx, ' -> ', len(rndInitList))
        taskList.append(rndInitList)
    testIdxs.append(taskList)

###
# Binary mode keeps the pickle byte-exact on every platform (and is required on
# Python 3); the context manager closes the file even if dumping fails.
with open(indicesFileName, 'wb') as out:
    pickle.dump({'trainIdxs': trainIdxs, 'testIdxs': testIdxs}, out)

print('')
print('done')