import pickle
import socket
import datetime
import numpy

###

import sys
import os
sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)),os.pardir))

###

import helperFunctions
import datasetAcquisition

###

if len(sys.argv) != 2:
    raise Exception('No config file given!')

print ''
print ' -- config -- '
print ''

numTasks = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numTasks', None, 'int', True)
numRndInits = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numRndInits', None, 'int', True)
numCls = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numCls', None, 'int', True)
numInitCls = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numInitCls', None, 'int', True)
numInitSamplesPerCls = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numInitSamplesPerCls', None, 'int', True)
numTestSamplesPerCls = helperFunctions.getConfig(sys.argv[1], 'experiment', 'numTestSamplesPerCls', None, 'int', True)

forbiddenCls = helperFunctions.getConfig(sys.argv[1], 'data', 'forbiddenCls', [], 'intList', True)
indicesFileName = helperFunctions.getConfig(sys.argv[1], 'data', 'indicesFileName', None, 'str', True)

print ''
print 'host:', socket.gethostname()
print 'pid:', os.getpid()
print 'now:', datetime.datetime.strftime(datetime.datetime.now(), '%d.%m.%Y %H:%M:%S')
print 'git:', helperFunctions.getGitHash()
print ''
sys.stdout.flush()

###

trainIdxs = list()
testIdxs = list()

###

print''
xTrain, yTrain, xTest, yTest = datasetAcquisition.readData(sys.argv[1])

###

print ''
print ' -- train --'
print ''
uniY = numpy.unique(numpy.asarray(yTrain))

###

for fCLs in forbiddenCls:
    uniY = numpy.delete(uniY, numpy.where(uniY == fCLs), axis=0)

for fCLs in forbiddenCls:
    uniY = numpy.delete(uniY, numpy.where(uniY == -1), axis=0)

###

for taskIdx in range(numTasks):

    initCls = uniY[numpy.random.permutation(len(uniY))[:numInitCls]]
    taskList = list()

    for rndInitIdx in range(numRndInits):

        rndInitList = list()

        for clsIdx in range(len(initCls)):

            clsSamples = numpy.ravel(numpy.where(numpy.ravel(yTrain) == numpy.ravel(initCls[clsIdx])))

            if rndInitIdx == 0:
                print 'cls', initCls[clsIdx], 'with', len(clsSamples), 'samples chosen'
                if len(clsSamples) < numInitSamplesPerCls:
                    #raise Exception('To few samples!')
                    print '>>> to few samples!'
                sys.stdout.flush()

            rndInitList.extend(list(clsSamples[numpy.random.permutation(len(clsSamples))[:numInitSamplesPerCls]]))

        print 'samples gathered for rndinit', rndInitIdx, 'of task', taskIdx, ' -> ', len(rndInitList)

        taskList.append(list(rndInitList))

    trainIdxs.append(list(taskList))

###

print ''
print ' -- test --'
print ''

uniY = numpy.unique(numpy.asarray(yTest))

###

for fCLs in forbiddenCls:
    uniY = numpy.delete(uniY, numpy.where(uniY == fCLs), axis=0)

for fCLs in forbiddenCls:
    uniY = numpy.delete(uniY, numpy.where(uniY == -1), axis=0)

###


for taskIdx in range(numTasks):

    taskList = list()

    for rndInitIdx in range(numRndInits):

        rndInitList = list()

        for clsIdx in range(len(uniY)):

            clsSamples = numpy.ravel(numpy.where(numpy.ravel(yTest) == numpy.ravel(uniY[clsIdx])))

            if taskIdx == 0 and rndInitIdx == 0:
                print 'samples available for cls', uniY[clsIdx], ' -> ', len(clsSamples)
                if len(clsSamples) < numTestSamplesPerCls:
                    #raise Exception('To few samples!')
                    print '>>> to few samples!'
                sys.stdout.flush()

            rndInitList.extend(list(clsSamples[numpy.random.permutation(len(clsSamples))[:numTestSamplesPerCls]]))

        print 'samples gathered for rndinit', rndInitIdx, 'of task', taskIdx, ' -> ', len(rndInitList)

        taskList.append(list(rndInitList))

    testIdxs.append(list(taskList))

###

out = open(indicesFileName, 'w')
pickle.dump({'trainIdxs': trainIdxs, 'testIdxs': testIdxs}, out)
out.close()
print ''
print 'done'