Kaynağa Gözat

Add glamina module

valapil 2 yıl önce
ebeveyn
işleme
f22747a1bb
1 değiştirilmiş dosya ile 435 ekleme ve 0 silme
  1. 435 0
      glamina

+ 435 - 0
glamina

@@ -0,0 +1,435 @@
+import numpy as np
+import pandas as pd
+import math
+from scipy.signal import medfilt
+import matplotlib.pyplot as plt
+
def standardize(x):
    """
    Z-score normalisation: subtract the per-column mean and divide by the
    per-column standard deviation.

    :param x: array-like or pandas series/dataframe
    :return: standardized array or series (same shape as the input)
    """
    center = np.mean(x, axis=0)
    spread = np.std(x, axis=0)
    return (x - center) / spread
+
+
def smoothing(x):
    """Denoise a signal: a centered rolling median (window 5) followed by a
    centered Gaussian-weighted rolling mean (window 5, std=20).

    :param x: 1d or 2d array-like signal; edges of the result are NaN
        because the centered windows are incomplete there
    :return: smoothed values as a numpy array
    """
    smoothed = (
        pd.DataFrame(x)
        .rolling(window=5, center=True)
        .median()
        .rolling(window=5, win_type='gaussian', center=True)
        .mean(std=20)
    )
    return smoothed.to_numpy()
+
+
+
def intervalsToBinary(a, tslength):
    """Render a list of (start, duration) intervals as a binary array of
    length tslength, where 1 marks frames inside an interval.

    E.g. a=[(3,4)], tslength=10 returns [0,0,0,1,1,1,1,0,0,0].
    """
    mask = np.zeros(tslength)
    for entry in a:
        lo = int(entry[0])
        mask[lo:lo + int(entry[1])] = 1
    return mask
+
def binaryToIntervals(array):
    """Transform a binary array, where 1 marks interval membership, into an
    array of (start, count) intervals.

    E.g. array=[0,0,0,1,1,1,1,0,0,0] returns [[3, 4]].

    Fixes over the previous version: the trailing run of ones is always
    emitted (it used to be dropped unless it reached the final index, so
    even the docstring example returned an empty result), and an all-zero
    input returns an empty (0, 2) array instead of raising IndexError.

    :param array: 1d array of 0/1 values
    :return: numpy array of shape (k, 2) with rows [start, count]
    """
    ones = np.where(np.asarray(array) == 1)[0]
    intervals = []
    if ones.size > 0:
        start = int(ones[0])
        count = 1
        for idx in ones[1:]:
            if idx == start + count:
                # still inside the current consecutive run
                count += 1
            else:
                intervals += [start, count]
                start = int(idx)
                count = 1
        intervals += [start, count]  # always emit the final run
    return np.array(intervals).reshape(-1, 2)
+
+
def median_filter_intervals(a, kernel):
    """Close small gaps between detected intervals with a median filter.

    The intervals are rendered to a binary time series, median-filtered,
    then merged (element-wise max) with the original binary signal so no
    originally-detected frame is ever removed, and converted back to
    (start, duration) intervals.

    :param a: array-like of (start, duration) intervals; the last interval
        is assumed to define the series length via start + duration
    :param kernel: odd kernel size passed to scipy.signal.medfilt
    :return: numpy array of filtered (start, duration) intervals
    """
    tslength = int(np.sum(a[-1]))
    binary = intervalsToBinary(a, tslength)
    filtered = medfilt(binary, kernel_size=kernel)
    # never drop frames that were detected before filtering
    filtered = np.maximum.reduce([filtered, binary])
    return binaryToIntervals(filtered)
+
+
+
def arguments(X, Y, thresCorrelation, minLenInt,
               maxLenInt, shiftTS, distancemeasure):
    """Build argument tuples for every (shift, minimum-length) combination:
    the unshifted pair plus the forward-shifted versions first, then the
    backward-shifted versions.

    :return: list of (X, Y, thresCorrelation, minLen, maxLenInt,
        distancemeasure) tuples
    """
    forward = []
    backward = []
    for shift in shiftTS:
        for minLen in minLenInt:
            if shift == 0:
                forward.append((X, Y, thresCorrelation, minLen,
                                maxLenInt, distancemeasure))
            else:
                forward.append((X[shift:], Y[:-shift], thresCorrelation,
                                minLen, maxLenInt, distancemeasure))
                backward.append((X[:-shift], Y[shift:], thresCorrelation,
                                 minLen, maxLenInt, distancemeasure))
    return forward + backward
+
+
def mutual_information_short(hgram):
    """Mutual information (in nats) of the joint distribution given as a
    2d histogram of counts.

    :param hgram: 2d numpy array of joint counts
    :return: mutual information between the two histogram axes
    """
    joint = hgram / float(np.sum(hgram))
    marginal_x = joint.sum(axis=1)
    marginal_y = joint.sum(axis=0)
    independent = np.outer(marginal_x, marginal_y)
    support = joint > 0  # skip zero cells so log() stays finite
    return np.sum(joint[support] * np.log(joint[support] / independent[support]))
+
+
def ts_distance(distancemeasure, X, Y, bins):
    """
    Similarity measure used by the relevant-interval selection approach:
    either Pearson correlation or normalized mutual information. New
    similarity measures can be added with further elif branches.

    :param distancemeasure: 'pearson' or 'mutual'
    :param X: Time series
    :param Y: Time series
    :param bins: Bins for the mutual-information histogram

    :return: similarity of X and Y (nan for constant input with 'pearson')
    :raises ValueError: for an unknown distancemeasure (previously this
        fell through to an UnboundLocalError on `r`)
    """
    import warnings
    if distancemeasure == 'pearson':
        with warnings.catch_warnings():
            # constant input makes corrcoef emit a warning and return nan;
            # callers filter nan results themselves
            warnings.simplefilter("ignore")
            r = np.corrcoef(X, Y)[0, 1]
    elif distancemeasure == 'mutual':
        # joint histogram over the common value range of both series
        concat = np.hstack((X, Y))
        l, h = np.nanmin(concat), np.nanmax(concat)
        hgram = np.histogram2d(X, Y, range=[[l, h], [l, h]], bins=bins)[0]
        r = mutual_information_short(hgram)
        # normalize mutual information into [0, 1)
        r = np.sqrt(1 - np.exp(-2 * r))
    else:
        raise ValueError(f"unknown distancemeasure: {distancemeasure!r}")
    return r
+
def LongestSet(IntList):
    """
    Calculate longest set of maximal correlated intervals

    :param IntList: 2d list of intervals, e.g. [[0,2], [1,2], [1,4], [4,3],
                    [7,2], [8,2], [4,3]]

    :return: numpy array as list of intervals where sum of length of
    intervals is largest regarding all possible combinations of
    non-overlapping intervals
    """
    # Normalise ndarray input to a list of (start, duration) tuples so it
    # can be sorted lexicographically (by start, then duration) below.
    if type(IntList) == np.ndarray:
        IntList = [(x[0], x[1]) for x in IntList]

    k = len(IntList)
    if k == 0:
        return 0  # NOTE(review): empty input returns 0, not an empty array
    IntList = np.array(sorted(IntList))
    NextInt = {}

    # Compute immediate starting intervals after i ends
    for i, Interval in enumerate(IntList):
        end = Interval[0] + Interval[1]  # Interval[0] is the starting frame and Interval[1] is the duration
        nextidx = IntList[IntList[:, 0] >= end]
        if nextidx.size > 0:
            # indices (into the sorted list) of every candidate successor
            # starting at or after the earliest non-overlapping one
            minstart = np.min(nextidx[:, 0])
            NextInt[i] = np.argwhere(IntList[:, 0] >= minstart)
        else:
            NextInt[i] = None

    # Calculate cumulated sum iterating by end of the list
    # (dynamic programming backwards: CumSum[i] is the best total duration
    # achievable starting with interval i; SelInts[i] stores the successor
    # chosen for that optimum)
    CumSum = np.zeros(k)
    SelInts = {}
    for idx, Interval in enumerate(reversed(IntList)):
        i = (k - 1) - idx
        b = NextInt[i]
        # b is either None or an index array; np.all handles both forms
        if np.all(b != None):
            # pick the successor that maximises the accumulated duration
            bestfollower = int(NextInt[i][np.argmax(
                CumSum[NextInt[i]] + Interval[1])])
            CumSum[i] = int(Interval[1] + CumSum[bestfollower])
            if Interval[1] + CumSum[bestfollower] >= CumSum[i + 1]:
                SelInts[i] = bestfollower
        else:
            # no non-overlapping successor: the chain ends here
            CumSum[i] = Interval[1]
            SelInts[i] = None

    # Loop forward
    # (follow the recorded successors starting from the best-scoring start)
    Result = np.array([])
    current = np.where(CumSum == CumSum.max())[0][-1]
    while True:
        intval = IntList[current]
        Result = np.append(Result, [intval])
        current = SelInts[current]
        if current not in SelInts:
            break

    Result = Result.reshape(-1, 2)
    return Result
+
+
def glamina(X: np.ndarray, Y: np.ndarray, threshold: float, l_min: list[int], l_max: int, col_x: list[int] | None = None, col_y: list[int] | None = None, shifts: list[int] | None = None, distancemeasure: str = 'pearson'):
    """
    Finds the relevant intervals among multiple time series based on correlation between them.

    :param X: whole data of subject 1 with columns holding each time series
    :param Y: whole data of subject 2 with columns holding each time series
    :param threshold: threshold value starting from which the correlation is significant
    :param l_min: list of minimum length of continuous relevant intervals to be considered
    :param l_max: maximum length of continuous relevant intervals
    :param col_x: list of selected columns in subject 1. By default, takes in all columns.
    :param col_y: list of selected columns in subject 2 with respect to col_x. By default, takes in all columns.
    :param shifts: list of values with which X is shifted back and forth in time.
    :param distancemeasure: correlation metric. By default, Pearson's correlation

    :return: selected relevant intervals as rows of [starting time, duration]
        (empty list when nothing correlates above the threshold)
    :raises ValueError: if only one of col_x/col_y is given or their lengths differ
    """
    # Column selection / validation. Previously `len(col_x)` was evaluated
    # before the None check, so the documented default (both None) crashed
    # with a TypeError.
    if col_x is not None or col_y is not None:
        if col_x is None or col_y is None or len(col_x) != len(col_y):
            raise ValueError("Missing respective column indices in function call. There is a shape mismatch.")
        X = X[:, col_x]
        Y = Y[:, col_y]

    if shifts is None:
        shifts = [0]

    interval = []
    laminaresult = []

    # loop for each minimum interval length
    for m in l_min:
        # Build all unshifted / forward / backward pairs. The shifted pairs
        # are iterated under fresh names (Xs, Ys) instead of rebinding X/Y,
        # which previously corrupted later l_min iterations when shifts
        # contained non-zero values.
        d1 = [[X[s:, :], Y[:-s, :]] if s != 0 else [X, Y] for s in shifts]
        d2 = [[X[:-s, :], Y[s:, :]] for s in shifts if s != 0]
        for Xs, Ys in d1 + d2:
            indexlength = min(Xs.shape[0], Ys.shape[0])
            # IntMat[i, w] accumulates the channel-averaged correlation of
            # the interval starting at i with length m + w
            IntMat = np.zeros([indexlength, l_max - m + 1], dtype=np.float64)
            n_channel = Xs.shape[1]

            # average correlation over channels for the minimum interval length
            for col in range(n_channel):
                x = Xs[:, col]
                y = Ys[:, col]
                for i in range(indexlength - m + 1):
                    r = ts_distance(distancemeasure, x[i:i + m],
                                    y[i:i + m], bins=10)
                    if not math.isnan(r):
                        IntMat[i, 0] += r
            IntMat = IntMat / n_channel

            # Grow intervals beyond the minimum length: only extend where
            # both length-(winlen-1) sub-intervals already pass the
            # threshold. The previous version never reassigned x/y inside
            # this loop, so every "channel" silently reused the last channel
            # of the loop above; each channel is now actually used.
            for col in range(n_channel):
                x = Xs[:, col]
                y = Ys[:, col]
                for winlen in range(m + 1, l_max + 1):
                    for i in range(indexlength - winlen + 1):
                        if IntMat[i, winlen - 1 - m] >= threshold and \
                                IntMat[i + 1, winlen - 1 - m] >= threshold:
                            r = ts_distance(distancemeasure, x[i:i + winlen],
                                            y[i:i + winlen], bins=10)
                            if not math.isnan(r):
                                IntMat[i, winlen - m] += r

            IntMat = IntMat[:, 1:] / n_channel

            # (start index, interval length) pairs above the threshold
            CorInts = np.where(IntMat >= threshold)
            del IntMat
            CorInts = list(zip(CorInts[0], CorInts[1] + m))

            # keep only maximal intervals: those not contained in a longer
            # correlated interval (set lookup instead of O(n) list scans)
            corset = set(CorInts)
            ResultInt = []
            for i, lenWin in CorInts:
                if (i - 1, lenWin + 1) not in corset and (i, lenWin + 1) not in corset:
                    ResultInt += [(i, lenWin)]
            del CorInts

            interval.append(ResultInt)

    # flatten all per-run interval lists (loop variable renamed so the
    # builtin `int` is no longer shadowed)
    Result = [iv for intlist in interval for iv in intlist]

    if len(Result) > 0:
        laminaresult = LongestSet(Result)  # find the longest set of intervals from Result

    return laminaresult
+
+
def gen_series(lenTS=1000, start=50, end=300, amplitude=50, noise=1, seed=10):
    """Generate a synthetic test signal: a linear ramp from 5 to `amplitude`
    spanning the whole series, zeroed outside [start, end], plus optional
    Gaussian noise.

    :param lenTS: total length of the series
    :param start: first index of the active (non-zero) region
    :param end: last index of the active region
    :param amplitude: final value of the underlying ramp
    :param noise: scale of the additive Gaussian noise (0 = noise-free)
    :param seed: random seed used when noise is added
    :return: 1d numpy array of length lenTS
    """
    ramp = np.linspace(5, amplitude, lenTS)
    ramp[:start] = 0
    ramp[end + 1:] = 0
    if noise != 0:
        np.random.seed(seed)
        noise = noise * np.random.randn(lenTS)
    return ramp + noise
+
+
if __name__ == "__main__":
    # Demo: run glamina on two synthetic two-channel recordings, close small
    # gaps with a median filter, and plot the detections. The commented-out
    # sections are alternative fNIRS-data runs and richer plots.

    # np.random.seed(10)

    # shifts = [0]
    # threshold = 0.99
    # minLenInts = [70]
    # maxLenInts = 200
    # distancemeasure = 'pearson'
    #
    # medfiltlength = 51
    #
    # path = '/home/datasets4/fNIRS/fNIRSData/MI-003/Sub1/Sub1_preprocessed.csv'
    # df = pd.read_csv(path)
    # data_1 = df.to_numpy()
    #
    # path = '/home/datasets4/fNIRS/fNIRSData/MI-003/Sub2/Sub2_preprocessed.csv'
    # df = pd.read_csv(path)
    # data_2 = df.to_numpy()
    #
    # data1 = standardize(data_1)
    # data2 = standardize(data_2)
    #
    # data1 = smoothing(data1)
    # data2 = smoothing(data2)

    shifts = [0]  # shifts that must be considered
    threshold = 0.87 # (imp) depends on the correlation values
    minLenInts = [50]  # (imp) depends on the minimum continuous interval that must be detected
    maxLenInts = 250  # (less relevant) depends on the max continuous interval
    distancemeasure = 'pearson'

    medfiltlength = 51

    # Two synthetic subjects, two channels each, with overlapping ramps.
    S1 = gen_series(start=100, end=300, amplitude=20, noise=1, seed=10)
    R1 = gen_series(start=120, end=320, amplitude=20, noise=1, seed=11)
    S2 = gen_series(start=90, end=280, amplitude=20, noise=1, seed=12)
    R2 = gen_series(start=200, end=260, amplitude=20, noise=1, seed=13)

    data1 = np.column_stack((S1, S2))
    data2 = np.column_stack((R1, R2))

    # Detect correlated intervals, then merge nearby ones via median filter.
    result = glamina(data1, data2, threshold=threshold, l_min=minLenInts, l_max=maxLenInts, col_x=[0,1], col_y=[0,1], shifts=shifts, distancemeasure=distancemeasure)
    Result = median_filter_intervals(result, kernel=medfiltlength)

    # import matplotlib.pyplot as plt
    # plt.subplot(5, 1, 1)
    # plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k')
    # plt.legend(loc="upper left")
    # plt.subplot(5, 1, 2)
    # plt.plot(np.arange(0, len(S2)), np.array(S2), label='x2', color='b')
    # plt.legend(loc="upper left")
    # plt.subplot(5, 1, 3)
    # plt.plot(np.arange(0, len(R1)), np.array(R1), label='y1', color='k')
    # plt.legend(loc="upper left")
    # plt.subplot(5, 1, 4)
    # plt.plot(np.arange(0, len(R2)), np.array(R2), label='y2', color='b')
    # plt.legend(loc="upper left")
    # plt.subplot(5, 1, 5)
    # plt.plot(np.arange(0, len(R2)), np.array(gen_series(start=200, end=260, amplitude=30, noise=0)),
    #          label='ground truth', color='g')
    # if len(Result) > 0:
    #     for idx, i in enumerate(Result):
    #         plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
    #                     label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
    # plt.legend(loc="upper left")
    # plt.show()

    # S1 = data1[:, 6]
    # S2 = data1[:, 7]
    # S3 = data1[:, 8]
    # R1 = data2[:, 6]
    # R2 = data2[:, 7]
    # R3 = data2[:, 8]
    #
    # S4 = data1[:, 12]
    # R4 = data2[:, 12]
    #
    # plt.subplot(7, 1, 1)
    # plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k')  # color = bgrcmykw
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 2)
    # plt.plot(np.arange(0, len(S2)), np.array(S2), label='x2', color='b')
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 3)
    # plt.plot(np.arange(0, len(S3)), np.array(S3), label='x3', color='m')
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 4)
    # plt.plot(np.arange(0, len(R1)), np.array(R1), label='y1', color='k')
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 5)
    # plt.plot(np.arange(0, len(R2)), np.array(R2), label='y2', color='b')
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 6)
    # plt.plot(np.arange(0, len(R3)), np.array(R3), label='y3', color='m')
    # plt.legend(loc="upper left")
    #
    # path = '/home/valapil/Project/fnirs_file/annotations/MI-003_clean.csv'
    # df = pd.read_csv(path)
    # truth = df.to_numpy()
    # ind = np.where(truth[:, 1] != 10)[0]
    # sig = np.zeros(data1.shape[0])
    # a = truth[ind, :].astype(np.int16)
    # for s, e in a[:, [0, 2]]:
    #     sig[s:e + 1] = 1
    # plt.subplot(7, 1, 7)
    # plt.plot(np.arange(0, len(sig)), sig, label='ground truth', color='g')
    #
    # if len(Result) > 0:
    #     for idx, i in enumerate(Result):
    #         plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
    #                     label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
    # plt.legend(loc="upper left")

    # Plot the first channel with detected intervals shaded in red.
    plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k')
    if len(Result) > 0:
        for idx, i in enumerate(Result):
            plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
                        label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
    plt.legend(loc="upper left")
    print(Result)
    plt.show()