import numpy as np
import pandas as pd
import math
from scipy.signal import medfilt
import matplotlib.pyplot as plt

def standardize(x):
    """
    Z-score the input along axis 0: subtract the mean, then divide by the
    standard deviation (both computed with numpy along axis 0).

    :param x: pandas series (or any array-like accepted by np.mean / np.std)
    :return: standardized array or series, same shape as x
    """
    centered = x - np.mean(x, axis=0)
    spread = np.std(x, axis=0)
    return centered / spread


def smoothing(x):
    """
    Smooth every column of x in two passes: a centered rolling median
    (window 5) followed by a centered Gaussian-weighted rolling mean
    (window 5, std 20).

    :param x: data accepted by the pandas DataFrame constructor
    :return: 2d numpy array of the smoothed values; positions whose rolling
             window is incomplete (near both ends) come back as NaN
    """
    frame = pd.DataFrame(x)
    despiked = frame.rolling(window=5, center=True).median()
    gaussian_windows = despiked.rolling(window=5, win_type='gaussian', center=True)
    return gaussian_windows.mean(std=20).to_numpy()



def intervalsToBinary(a, tslength):
    """ Transform a list of (start, length) intervals into a binary array of
    length tslength, where 1 marks samples inside an interval and 0 marks
    samples outside.
    E.g. a=[(3,4)], tslength=10 returns [0,0,0,1,1,1,1,0,0,0]"""

    mask = np.zeros(tslength)
    if len(a) == 0:
        return mask

    for interval in a:
        begin = int(interval[0])
        # Slicing clips silently, so intervals running past tslength are cut.
        mask[begin:begin + int(interval[1])] = 1
    return mask

def binaryToIntervals(array):
    """ Transform a 1-d binary array, where 1 is interval and 0 is no interval,
    to a list of intervals given as (start, length).
    E.g. array=[0,0,0,1,1,1,1,0,0,0] returns [(3,4)]

    :param array: 1-d array-like; entries equal to 1 are inside an interval
    :return: integer numpy array of shape (k, 2), one [start, length] row per
             run of consecutive ones, in order of appearance. The shape is
             (0, 2) when the array contains no ones.
    """
    mask = np.asarray(array) == 1

    # Pad with a zero on both sides so that runs touching either end of the
    # array still produce a rising and a falling edge.
    padded = np.concatenate(([0], mask.astype(np.int8), [0]))
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)   # index of the first 1 of each run
    stops = np.flatnonzero(edges == -1)   # index just past the last 1 of each run

    return np.column_stack((starts, stops - starts))


def median_filter_intervals(a, kernel):
    """
    Close short gaps between intervals with a median filter.

    The intervals are rasterized to a binary series, median filtered, and the
    filtered series is merged (element-wise maximum) with the unfiltered one,
    so the filter can only add samples and never removes an input interval.

    :param a: intervals as rows of [start, length]; may be empty
    :param kernel: median filter kernel size (scipy's medfilt requires an odd
                   value)
    :return: numpy array of shape (k, 2) with the merged [start, length]
             intervals; shape (0, 2) when there is nothing to filter
    """
    a = np.asarray(a)
    if a.size == 0:
        # Nothing detected upstream: avoid indexing into an empty list.
        return np.empty((0, 2), dtype=int)
    a = a.reshape(-1, 2)

    # The series must be long enough for the interval that ends last, which
    # is not necessarily the last row.
    tslength = int(np.max(a[:, 0] + a[:, 1]))
    binary = intervalsToBinary(a, tslength)
    if not binary.any():
        return np.empty((0, 2), dtype=int)

    filtered = medfilt(binary, kernel_size=kernel)
    return binaryToIntervals(np.maximum(filtered, binary))



def arguments(X, Y, thresCorrelation, minLenInt,
               maxLenInt, shiftTS, distancemeasure):
    """
    Build one argument tuple per (shift, minimum length) combination.

    For a non-zero shift two alignments are produced: X moved forward against
    Y (X[shift:], Y[:-shift]) and X moved backward (X[:-shift], Y[shift:]).
    A shift of 0 yields the unshifted pair once.

    :return: list of tuples (X, Y, thresCorrelation, minLen, maxLenInt,
             distancemeasure); all no-shift/forward tuples come first,
             followed by all backward tuples
    """
    forward = []
    backward = []
    for shift in shiftTS:
        for minLen in minLenInt:
            tail = (thresCorrelation, minLen, maxLenInt, distancemeasure)
            if shift == 0:
                forward.append((X, Y) + tail)
                continue
            forward.append((X[shift:], Y[:-shift]) + tail)
            backward.append((X[:-shift], Y[shift:]) + tail)
    return forward + backward


def mutual_information_short(hgram):
    """
    Mutual information (natural log) of a 2d joint histogram.

    :param hgram: 2d numpy array of joint counts
    :return: sum over non-empty cells of p(x,y) * log(p(x,y) / (p(x) p(y)))
    """
    joint = hgram / float(np.sum(hgram))
    marginal_rows = np.sum(joint, axis=1)
    marginal_cols = np.sum(joint, axis=0)
    independent = np.outer(marginal_rows, marginal_cols)

    # Empty cells contribute nothing and would otherwise produce log(0).
    occupied = joint > 0
    ratio = joint[occupied] / independent[occupied]
    return np.sum(joint[occupied] * np.log(ratio))


def ts_distance(distancemeasure,X,Y,bins):
    """
    Similarity measures that can be called in relevant interval selection
    approach. Either Pearson correlation or Normalized mutual information.
    New similarity measures can be added with further branches.

    :param distancemeasure: 'pearson' or 'mutual'
    :param X: Time series
    :param Y: Time series
    :param bins: Bins for mutual information computation (ignored for
                 'pearson')

    :return: returns similarity of X and Y. For 'pearson' this is NaN when
             the correlation is undefined (e.g. a constant series).
    :raises ValueError: if distancemeasure is not a supported name
    """
    import warnings

    if distancemeasure == 'pearson':
        # corrcoef warns on constant input; callers handle the resulting NaN.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return np.corrcoef(X, Y)[0, 1]

    if distancemeasure == 'mutual':
        # Use one common value range for both axes of the joint histogram.
        concat = np.hstack((X, Y))
        l, h = np.nanmin(concat), np.nanmax(concat)
        hgram = np.histogram2d(X, Y, range=[[l, h], [l, h]], bins=bins)[0]
        mi = mutual_information_short(hgram)
        # Normalize mutual information. Rounding can leave mi marginally
        # below zero, which would make the square root NaN, so clamp at 0.
        return np.sqrt(np.maximum(1 - np.exp(-2 * mi), 0.0))

    raise ValueError(
        f"Unknown distancemeasure {distancemeasure!r}; expected 'pearson' or 'mutual'.")

def LongestSet(IntList):
    """
    Calculate longest set of maximal correlated intervals

    :param IntList: 2d list (or numpy array) of intervals given as
                    [start, duration], e.g. [[0,2], [1,2], [1,4], [4,3],
                    [7,2], [8,2], [4,3]]

    :return: float numpy array of shape (n, 2) listing the intervals, in
    order of start, whose summed duration is largest among all chains of
    non-overlapping intervals (an interval may follow another only if it
    starts at or after the other's end). For an empty IntList the integer 0
    is returned instead, kept as-is for backward compatibility.
    """
    if isinstance(IntList, np.ndarray):
        IntList = [(x[0], x[1]) for x in IntList]

    k = len(IntList)
    if k == 0:
        return 0
    IntList = np.array(sorted(IntList))
    starts = IntList[:, 0]
    lengths = IntList[:, 1]

    # The intervals are sorted by start, so everything that may follow
    # interval i (start >= end of i) is the suffix beginning here.
    first_follower = np.searchsorted(starts, starts + lengths, side='left')

    # CumSum[i]: largest total duration of a chain that begins with interval i.
    # follower[i]: the interval chosen right after i in that chain (or None).
    CumSum = np.zeros(k)
    follower = [None] * k
    for i in range(k - 1, -1, -1):
        # Never look at i itself or earlier rows; this only matters for
        # zero-length intervals, which could otherwise select themselves.
        j0 = max(int(first_follower[i]), i + 1)
        if j0 < k:
            best = j0 + int(np.argmax(CumSum[j0:]))
            follower[i] = best
            CumSum[i] = lengths[i] + CumSum[best]
        else:
            CumSum[i] = lengths[i]

    # Walk the best chain forward; ties are resolved towards the latest start.
    current = int(np.where(CumSum == CumSum.max())[0][-1])
    chain = []
    while current is not None:
        chain.append(IntList[current])
        current = follower[current]

    return np.array(chain, dtype=float).reshape(-1, 2)


def glamina(X: np.ndarray, Y: np.ndarray, threshold: float, l_min: list[int], l_max: int, col_x: list[int] | None =None, col_y: list[int] | None =None, shifts: list[int] | None =None, distancemeasure: str ='pearson'):
    """
    Finds the relevant intervals among multiple time series based on correlation between them.

    :param X: whole data of subject 1 with columns holding each time series
    :param Y: whole data of subject 2 with columns holding each time series
    :param threshold: threshold value starting from which the correlation is significant
    :param l_min: list of minimum length of continuous relevant intervals to be considered
    :param l_max: maximum length of continuous relevant intervals
    :param col_x: list of selected columns in subject 1. By default, takes in all columns.
    :param col_y: list of selected columns in subject 2 with respect to col_x. By default, takes in all columns.
    :param shifts: list of values with which X is shifted back and forth in time.
    :param distancemeasure: correlation metric. By default, Pearson's correlation

    :return: List of lists containing selected relevant intervals as [starting time, duration].
    """
    # error handling
    if (len(col_x) != len(col_y)):
        raise ValueError("Missing respective column indices in function call. There is a shape mismatch.")

    if shifts ==  None:
        shifts = [0]

    # to select the columns 
    if (col_x != None and col_y != None):
        X = X[:, col_x]
        Y = Y[:, col_y]

    interval = []
    Result = []
    laminaresult = []
    x = []
    y = []

    # loop for each minimum interval length
    for m in l_min:
        d1 = [[X[s:, :], Y[:-s, :]] if s != 0 else [X, Y] for s in shifts]
        d2 = [[X[:-s, :], Y[s:, :]] for s in shifts if s != 0]
        d = d1 + d2
        for X, Y in d:
            indexlength = min(X.shape[0], Y.shape[0])
            IntMat = np.zeros([indexlength, l_max - m + 1], dtype=np.float64)

            # calculate correlation for minimum interval length
            winlen = m
            n_channel = X.shape[1]

            # j = 0
            # loop to average correlation for minimum interval length
            for col in range(0, n_channel):
                x = X[:, col]
                y = Y[:, col]
                for i in range(indexlength - m + 1):
                    r = ts_distance(distancemeasure, x[i:i + winlen],
                                    y[i:i + winlen], bins=10)
                    if not math.isnan(r):
                        IntMat[i, 0] += r
                # j +=1
            IntMat = IntMat/n_channel


            # loop to average correlation for intervals above minimum interval
            for col in range(0, n_channel):
                for winlen in range(m + 1, l_max + 1):
                    for i in range(indexlength - winlen + 1):
                        if IntMat[i, winlen - 1 - m] >= threshold and \
                                IntMat[i + 1, winlen - 1 - m] >= threshold:
                            r = ts_distance(distancemeasure, x[i:i + winlen],
                                            y[i:i + winlen], bins=10)
                            if not math.isnan(r):
                                IntMat[i, winlen - m] += r


            IntMat = IntMat[:, 1:] / n_channel

            CorInts = np.where(IntMat >= threshold)
            del IntMat
            CorInts = list(zip(CorInts[0], CorInts[1] + m))

            # check if correlated intervals are maximal
            ResultInt = []
            for i, lenWin in CorInts:
                if ((i - 1, lenWin + 1) not in CorInts) and ((i, lenWin + 1) not in CorInts):
                    ResultInt += [(i, lenWin)]

            del CorInts

            interval.append(ResultInt)

    if len(interval) > 0:
        Result = [int for intlist in interval for int in intlist]  # flatten the array

    if len(Result) > 0:
        laminaresult = LongestSet(Result)  # find the longest set of intervals from Result

    return laminaresult


def gen_series(lenTS=1000, start=50, end=300, amplitude=50, noise=1, seed=10):
    """
    Generate a synthetic test series: a linear ramp from 5 to amplitude over
    the whole length, zeroed before index start and after index end, with
    optional Gaussian noise added on top.

    :param lenTS: number of samples
    :param start: first index that keeps its ramp value
    :param end: last index that keeps its ramp value
    :param amplitude: ramp value at the final sample (before zeroing)
    :param noise: scale of the standard-normal noise; 0 disables the noise
    :param seed: seed passed to np.random.seed when noise is enabled
                 (this reseeds numpy's global random generator)
    :return: numpy array of length lenTS
    """
    ramp = np.linspace(5, amplitude, lenTS)
    ramp[end + 1:] = 0
    ramp[:start] = 0

    if noise == 0:
        return ramp + noise

    np.random.seed(seed)
    return ramp + noise * np.random.randn(lenTS)


if __name__ == "__main__":
    # Demo on synthetic data: two subjects with two channels each, where the
    # ramps of the paired channels overlap only partially.
    #
    # To run on recorded data instead, load one CSV per subject with
    # pd.read_csv(path).to_numpy(), then apply standardize() and smoothing()
    # to each array before passing them to glamina().

    shifts = [0]  # shifts that must be considered
    threshold = 0.87  # (imp) depends on the correlation values
    minLenInts = [50]  # (imp) depends on the minimum continuous interval that must be detected
    maxLenInts = 250  # (less relevant) depends on the max continuous interval
    distancemeasure = 'pearson'

    medfiltlength = 51

    S1 = gen_series(start=100, end=300, amplitude=20, noise=1, seed=10)
    R1 = gen_series(start=120, end=320, amplitude=20, noise=1, seed=11)
    S2 = gen_series(start=90, end=280, amplitude=20, noise=1, seed=12)
    R2 = gen_series(start=200, end=260, amplitude=20, noise=1, seed=13)

    data1 = np.column_stack((S1, S2))
    data2 = np.column_stack((R1, R2))

    result = glamina(data1, data2, threshold=threshold, l_min=minLenInts, l_max=maxLenInts,
                     col_x=[0, 1], col_y=[0, 1], shifts=shifts, distancemeasure=distancemeasure)

    # glamina returns an empty list when nothing reaches the threshold, and
    # median_filter_intervals indexes the last interval, so skip it then.
    if len(result) > 0:
        Result = median_filter_intervals(result, kernel=medfiltlength)
    else:
        Result = np.empty((0, 2), dtype=int)

    # Plot the first channel of subject 1 and shade every detected interval.
    plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k')
    for idx, i in enumerate(Result):
        plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
                    label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
    plt.legend(loc="upper left")
    print(Result)
    plt.show()