
Upload files to ''

valapil 2 years ago
parent
commit
f98fe06fca
2 changed files with 276 additions and 0 deletions
  1. 06_causality.py (+218 −0)
  2. 07_causality_results.py (+58 −0)

+ 218 - 0
06_causality.py

@@ -0,0 +1,218 @@
+# Granger-causality analysis; pair IDs are passed as command-line args
+
+import numpy as np
+import pandas as pd
+from TimeSeriesClass import Intervallist, TimeSeries, TimeSeriesPair
+from RelevantIntervalSelection import selectRelevantIntervals
+from TSPerformanceMeasure import overlapintervals
+import argparse
+import concurrent.futures
+
+parser = argparse.ArgumentParser(description='Granger-causality analysis for the given pairs')
+parser.add_argument('--pairs', metavar='N', type=str, nargs='+', default=["04"],
+                    help='IDs of the pairs to process, e.g. 04 05 06')
+args = parser.parse_args()
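+# Example invocation (pair IDs illustrative):
+#   python 06_causality.py --pairs 04 05 06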
+
+
+SR = dict.fromkeys(['HappinessUpper','HappinessLower','SadnessUpper','SadnessLower'],0)
+RS = dict.fromkeys(['HappinessUpper','HappinessLower','SadnessUpper','SadnessLower'],0)
+bi = dict.fromkeys(['HappinessUpper','HappinessLower','SadnessUpper','SadnessLower'],0)
+noc = dict.fromkeys(['HappinessUpper','HappinessLower','SadnessUpper','SadnessLower'],0)
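+# per-expression tallies: SR = S causes E only, RS = E causes S only,
+# bi = both directions, noc = no causality; these are only updated by the
+# commented-out tally block near the end of the submission loop below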
+
+# path to offset data
+data = np.load('/home/valapil/Project/ForkCausal_Adithya/processed_data.npy')
+data = np.delete(data, 0, 0)  # drop the first row
+# pairs = ["04","05","06","07","08","09","010", "011", "012","013","014","015","016","017","018","019","020","021","022","023",
+#          "024","025","026","027","028","029","030","031","032","033","034","035","036","037"]
+# pairs = ["04"]
+
+# cache of relevant intervals per pair/cond/AU
+relavant = {}
+
+
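+# parameters for relevant-interval selection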
+distancemeasure = 'pearson'
+shifts = [0, 4, 8, 12]
+threshold = 0.7
+minLenInts = [75]
+maxLenInts = 800
+
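+# expressions mapped to the AU column indices that compose them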
+expr = {'HappinessUpper': [4], 'HappinessLower': [8,14],'SadnessUpper': [0,2], 'SadnessLower': [10,11]}
+# expr = {'HappinessLower': [8, 14]}
+
+model_order_limit = 120   # maximum model order for Granger causality
+SigLev = 0.05             # significance level
+medfiltlength = 51        # median-filter kernel for the interval masks
+ExtendInts = 12           # samples by which each relevant interval is extended
+
+
+
+def gc(pair, cond, e_key, e_val, a, b, model_order_limit, SigLev):
+    """Univariate Granger causality for one pair/condition/expression.
+
+    a and b hold the AU columns listed in e_val for speakers S and E.
+    Returns one result row: the full-series statistics followed by the
+    statistics restricted to the relevant intervals.
+    """
+    SE = TimeSeriesPair(X=TimeSeries(np.mean(a, axis=1).copy(), varname=e_key + 'S'),
+                        Y=TimeSeries(np.mean(b, axis=1).copy(), varname=e_key + 'E'))
+    ScE, EcS, pval_ScE, pval_EcS, order = SE.univariateGrangerCausality(
+        signiveau=SigLev, orderlimit=model_order_limit)
+
+    # loop over each AU of this expression; index into the arrays passed as
+    # arguments instead of the module-level data_S/data_E, which are not
+    # reliably visible inside worker processes
+    for i, au in enumerate(e_val):
+        x = a[:, i]
+        y = b[:, i]
+        ts = TimeSeriesPair(TimeSeries(x), TimeSeries(y))
+
+        try:
+            # passing the intervals to the class if already computed
+            relIntslocal = relavant[pair][cond][au]
+            ts.relInt['intervals'] = relIntslocal
+        except KeyError:
+            # otherwise compute relevant interval
+            ts.ComputeRelevantIntervals(shifts=shifts,
+                                        threshold=threshold,
+                                        minLenInts=minLenInts,
+                                        maxLenInts=maxLenInts,
+                                        distancemeasure=distancemeasure)
+
+            # cache the detected intervals; setdefault guards against pair/cond
+            # keys that were created in the parent after the workers were forked
+            # (the cache is per-process and is not shared back to the parent)
+            relavant.setdefault(pair, {}).setdefault(cond, {})[au] = \
+                Intervallist(ts.relInt['intervals'])
+
+    RelInts = relavant[pair][cond][e_val[0]]
+    RelInts.median_filter_intervals(kernel=medfiltlength)
+    # Loop through remaining Action Units.
+    if len(e_val) >= 2:
+        for AU in e_val[1:]:
+            # median filter the current AU's intervals
+            ri = relavant[pair][cond][AU]
+            ri.median_filter_intervals(kernel=medfiltlength)
+
+            # Compute intersection of previous Intervals with current
+            # Action Unit and store as new current interval selection.
+            RelInts = overlapintervals(RelInts.medfiltidx,
+                                       ri.medfiltidx)
+            RelInts = Intervallist(RelInts)
+            RelInts.median_filter_intervals(kernel=medfiltlength)
+
+    # Concatenate the relevant intervals
+    SE.concatRelevantIntervals(relInts=RelInts,
+                               addvalue=ExtendInts)
+    lengthRelInts = SE.intervals.lengthidx
+
+    # default to NaN so the returned row is well defined even when no
+    # relevant interval was found (previously the r_* names were undefined
+    # in that case and the return below raised a NameError)
+    r_ScE = r_EcS = r_pval_ScE = r_pval_EcS = r_order = np.nan
+
+    if lengthRelInts > 0:
+        r_ScE, r_EcS, r_pval_ScE, r_pval_EcS, r_order = SE.univariateGrangerCausality(
+            onRelInts=True,
+            signiveau=SigLev, orderlimit=model_order_limit)
+
+        print(pair, cond, e_key, '|all:', ScE, EcS, pval_ScE, pval_EcS, order,
+              '|relevant:', r_ScE, r_EcS, r_pval_ScE, r_pval_EcS, r_order)
+
+    return [pair, cond, e_key, ScE, EcS, pval_ScE, pval_EcS, order,
+            r_ScE, r_EcS, r_pval_ScE, r_pval_EcS, r_order]
+
+    # print("res fn", result)
+    # return result
+
+def standardize(x):
+    """
+    standardize an array column-wise: (x - mu) / sigma (mu: mean, sigma: standard deviation)
+    :param x: array or pandas series
+    :return: standardized array or series
+    """
+    mu = np.mean(x, axis=0)
+    sigma = np.std(x, axis=0)
+
+    return (x - mu) / sigma
+
+
+def smoothing(x):
+    """Rolling median (window 5) followed by a Gaussian-weighted rolling mean.
+    With center=True and the default min_periods, each pass leaves NaNs at
+    both ends of the output."""
+    df = pd.DataFrame(x)
+    df = df.rolling(window=5, center=True).median()
+    df = df.rolling(window=5, win_type='gaussian', center=True).mean(std=20)
+    return df.to_numpy()
+
+
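+# fan the (pair, cond, expression) combinations out to worker processes;
+# each task returns one result row, collected below as the futures complete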
+with concurrent.futures.ProcessPoolExecutor() as executor:
+    futures = []
+    for pair in args.pairs:
+        relavant[pair] = {}
+        # for cond in ["o", "g", "a"]:
+        for cond in ["o"]:
+            # select this speaker/condition and apply the confidence filter;
+            # the original three-argument np.logical_and call passed the third
+            # condition as the `out` parameter, so the confidence threshold
+            # was silently ignored
+            mask = (data[:, 0] == pair + "S") & (data[:, 1] == cond) & \
+                   (data[:, 3].astype(np.float64) > 0.89)
+            data_S = data[np.where(mask), 3:].reshape(-1, 17).astype(np.float64)
+
+            mask = (data[:, 0] == pair + "E") & (data[:, 1] == cond) & \
+                   (data[:, 3].astype(np.float64) > 0.89)
+            data_E = data[np.where(mask), 3:].reshape(-1, 17).astype(np.float64)
+
+            data_S = standardize(data_S)
+            data_E = standardize(data_E)
+
+            data_S = smoothing(data_S)
+            data_E = smoothing(data_E)
+
+            # truncate both speakers to the common length
+            n = min(data_S.shape[0], data_E.shape[0])
+            data_S = data_S[:n, :]
+            data_E = data_E[:n, :]
+
+            relavant[pair][cond] = {}
+
+            # for au in range(3,17):
+            #     x = data_E[:, au]
+            #     y = data_S[:, au]
+                # relavant[pair][cond][au] = selectRelevantIntervals(X=x,
+                #                         Y=y,
+                #                         shifts=shifts,
+                #                         threshold=threshold,
+                #                         minLenInts=minLenInts,
+                #                         maxLenInts=maxLenInts,
+                #                         distancemeasure=distancemeasure)
+
+            # loop for each expression
+            for e_key, e_val in expr.items():
+                a = data_S[:, e_val]
+                b = data_E[:, e_val]
+                # c = np.mean(data_S[:, e_val], axis=1)
+                # d = np.mean(data_E[:, e_val], axis=1)
+                future = executor.submit(gc, pair, cond, e_key, e_val, a, b, model_order_limit, SigLev)
+                futures.append(future)
+
+
+                # if ScE == True and EcS == True:
+                #     bi[e_key]+=1
+                #     break
+                # elif ScE == True:
+                #     SR[e_key] +=1
+                #     break
+                # elif EcS == True:
+                #     RS[e_key] +=1
+                #     break
+                # else:
+                #     noc[e_key] +=1
+
+                # temp = np.hstack((pair, cond, e_key, ScE, EcS, pval_ScE, pval_EcS, order))
+                # result = np.row_stack((result, temp))
+                # print(pair, cond, e_key, ScE, EcS, pval_ScE, pval_EcS, order)
+
+    result = ["pair", "cond", "expression", "ScE", "EcS", "pval_ScE", "pval_EcS", "order", "r_ScE", "r_EcS",
+              "r_pval_ScE", "r_pval_EcS", "r_order"]
+    # completed, _ = concurrent.futures.wait(futures)
+    for future in concurrent.futures.as_completed(futures):
+    # for completed_f in completed:
+        try:
+            result_row = future.result()  # result() returns the result from the executor
+            # result_row = completed_f.result()
+            result = np.row_stack((result, result_row))
+            # result_row.append()
+
+        except Exception as e:
+            print(e)
+
+# store the causality results; name the file after all processed pairs, since
+# `result` holds the rows for every pair in this run (the original f-string
+# used only the last value of the loop variable `pair`)
+np.save(f'/home/valapil/Project/ForkCausal_Adithya/results_causal/{"_".join(args.pairs)}.npy', result)
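+# To inspect a saved result file later (illustrative file name):
+#   res = np.load('/home/valapil/Project/ForkCausal_Adithya/results_causal/04.npy')
+#   print(res[0])    # header row
+#   print(res[1:])   # one row per (pair, cond, expression)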
+
+

+ 58 - 0
07_causality_results.py

@@ -0,0 +1,58 @@
+# process causality results
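+# reads every .npy file written by 06_causality.py and tallies, per
+# expression, how often each Granger-causality outcome occurred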
+
+import os
+
+import numpy as np
+
+
+# 8-slot counters per expression:
+#   [0] S->E only, [1] E->S only, [2] both, [3] neither   (full series)
+#   [4]-[7] the same four outcomes on the relevant intervals
+hu = [0]*8   # HappinessUpper
+hl = [0]*8   # HappinessLower
+su = [0]*8   # SadnessUpper
+sl = [0]*8   # SadnessLower
+
+folder = '/home/valapil/Project/ForkCausal_Adithya/results_causal'
+npy_files = [file for file in os.listdir(folder) if file.endswith('.npy')]
+
+def check(row, expr):
+    """Update the counter list `expr` from one result row; row[3]/row[4]
+    hold ScE/EcS on the full series, row[8]/row[9] the same on the
+    relevant intervals (np.save stores the booleans as strings)."""
+    if row[3] == 'True' and row[4] == 'True':
+        expr[2] += 1
+    elif row[3] == 'True' and row[4] == 'False':
+        expr[0] += 1
+    elif row[3] == 'False' and row[4] == 'True':
+        expr[1] += 1
+    elif row[3] == 'False' and row[4] == 'False':
+        expr[3] += 1
+
+    if row[8] == 'True' and row[9] == 'True':
+        expr[6] += 1
+    elif row[8] == 'True' and row[9] == 'False':
+        expr[4] += 1
+    elif row[8] == 'False' and row[9] == 'True':
+        expr[5] += 1
+    elif row[8] == 'False' and row[9] == 'False':
+        expr[7] += 1
+
+    return expr
+
+for file in npy_files:
+    file_path = os.path.join(folder, file)
+    data = np.load(file_path)
+
+    # skip the header row
+    for row in data[1:, :]:
+        if row[2] == 'HappinessUpper':
+            hu = check(row, hu)
+        elif row[2] == 'HappinessLower':
+            hl = check(row, hl)
+        elif row[2] == 'SadnessUpper':
+            su = check(row, su)
+        elif row[2] == 'SadnessLower':
+            sl = check(row, sl)
+
+print('HappinessLower:', hl)
+print('HappinessUpper:', hu)
+print('SadnessLower:', sl)
+print('SadnessUpper:', su)