# 06_causality.py

# Granger-causality analysis for the pairs passed as command-line arguments.
import argparse
import concurrent.futures

import numpy as np
import pandas as pd

from TimeSeriesClass import Intervallist, TimeSeries, TimeSeriesPair
from RelevantIntervalSelection import selectRelevantIntervals
from TSPerformanceMeasure import overlapintervals

parser = argparse.ArgumentParser(description='Pass pairs')
parser.add_argument('--pairs', metavar='N', type=str, nargs='+', default=["04"],
                    help='pass pairs one by one')
args = parser.parse_args()
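# Example invocation (one or more pair IDs):
#   python 06_causality.py --pairs 04 05 06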

# Per-expression tallies of causal direction (used only by the commented-out
# bookkeeping further below).
SR = dict.fromkeys(['HappinessUpper', 'HappinessLower', 'SadnessUpper', 'SadnessLower'], 0)
RS = dict.fromkeys(['HappinessUpper', 'HappinessLower', 'SadnessUpper', 'SadnessLower'], 0)
bi = dict.fromkeys(['HappinessUpper', 'HappinessLower', 'SadnessUpper', 'SadnessLower'], 0)
noc = dict.fromkeys(['HappinessUpper', 'HappinessLower', 'SadnessUpper', 'SadnessLower'], 0)

# Path to the offset data
data = np.load('/home/valapil/Project/ForkCausal_Adithya/processed_data.npy')
data = np.delete(data, 0, 0)  # drop the first row
# Full list of pairs:
# pairs = ["04", "05", "06", "07", "08", "09", "010", "011", "012", "013",
#          "014", "015", "016", "017", "018", "019", "020", "021", "022",
#          "023", "024", "025", "026", "027", "028", "029", "030", "031",
#          "032", "033", "034", "035", "036", "037"]
# pairs = ["04"]
relevant = {}  # cache: relevant intervals per pair -> condition -> AU column
distancemeasure = 'pearson'
shifts = [0, 4, 8, 12]
threshold = 0.7
minLenInts = [75]
maxLenInts = 800
# AU column indices (into the 17-column AU block) belonging to each expression
expr = {'HappinessUpper': [4], 'HappinessLower': [8, 14],
        'SadnessUpper': [0, 2], 'SadnessLower': [10, 11]}
# expr = {'HappinessLower': [8, 14]}
model_order_limit = 120
SigLev = 0.05
medfiltlength = 51
ExtendInts = 12
AUS = []


def gc(pair, cond, e_key, e_val, a, b, data_S, data_E, model_order_limit, SigLev):
    """Test Granger causality for one expression of one pair/condition.

    Runs the S->E (ScE) and E->S (EcS) tests on the full series and again
    restricted to the relevant intervals (the r_* values). Returns one row
    matching the result header assembled in the main process.
    """
    length_relevant_intervals = []
    SE = TimeSeriesPair(X=TimeSeries(np.mean(a, axis=1).copy(), varname=e_key + 'S'),
                        Y=TimeSeries(np.mean(b, axis=1).copy(), varname=e_key + 'E'))
    ScE, EcS, pval_ScE, pval_EcS, order = SE.univariateGrangerCausality(
        signiveau=SigLev, orderlimit=model_order_limit)
    # Each forked worker holds its own copy of `relevant`, so this cache only
    # avoids recomputation within a single worker process.
    cache = relevant.setdefault(pair, {}).setdefault(cond, {})
    # Loop over each Action Unit of this expression
    for au in e_val:
        x = data_S[:, au]
        y = data_E[:, au]
        ts = TimeSeriesPair(TimeSeries(x), TimeSeries(y))
        try:
            # Reuse the intervals if they were already computed
            ts.relInt['intervals'] = cache[au]
        except KeyError:
            # Otherwise compute the relevant intervals
            ts.ComputeRelevantIntervals(shifts=shifts,
                                        threshold=threshold,
                                        minLenInts=minLenInts,
                                        maxLenInts=maxLenInts,
                                        distancemeasure=distancemeasure)
            # Cache the detected relevant intervals
            cache[au] = Intervallist(ts.relInt['intervals'])
    RelInts = cache[e_val[0]]
    RelInts.median_filter_intervals(kernel=medfiltlength)
    # Loop through the remaining Action Units.
    if len(e_val) >= 2:
        for AU in e_val[1:]:
            # Median filter the current AU's intervals
            ri = cache[AU]
            ri.median_filter_intervals(kernel=medfiltlength)
            # Intersect the previous intervals with the current Action Unit
            # and keep the intersection as the new interval selection.
            RelInts = Intervallist(overlapintervals(RelInts.medfiltidx,
                                                    ri.medfiltidx))
            RelInts.median_filter_intervals(kernel=medfiltlength)
    # Concatenate the relevant intervals (extended by ExtendInts)
    SE.concatRelevantIntervals(relInts=RelInts, addvalue=ExtendInts)
    lengthRelInts = SE.intervals.lengthidx
    length_relevant_intervals += [pair, cond, e_key, lengthRelInts]
    # Default to NaN so the returned row is well defined even when no
    # relevant interval survives the filtering.
    r_ScE = r_EcS = r_pval_ScE = r_pval_EcS = r_order = np.nan
    if lengthRelInts > 0:
        r_ScE, r_EcS, r_pval_ScE, r_pval_EcS, r_order = SE.univariateGrangerCausality(
            onRelInts=True, signiveau=SigLev, orderlimit=model_order_limit)
    print(pair, cond, e_key, '|all:', ScE, EcS, pval_ScE, pval_EcS, order,
          '|relevant:', r_ScE, r_EcS, r_pval_ScE, r_pval_EcS, r_order)
    return [pair, cond, e_key, ScE, EcS, pval_ScE, pval_EcS, order,
            r_ScE, r_EcS, r_pval_ScE, r_pval_EcS, r_order]


def standardize(x):
    """
    Standardize an array column-wise: (x - mu) / sigma
    (mu: mean, sigma: standard deviation).
    :param x: array or pandas series
    :return: standardized array or series
    """
    mu = np.mean(x, axis=0)
    sigma = np.std(x, axis=0)
    return (x - mu) / sigma
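
# Illustrative check of standardize() on a small array:
#   standardize(np.array([[1., 2.], [3., 4.]]))  ->  [[-1., -1.], [1., 1.]]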

def smoothing(x):
    """Median filter, then Gaussian smoothing (both centered, window 5)."""
    df = pd.DataFrame(x)
    df = df.rolling(window=5, center=True).median()
    df = df.rolling(window=5, win_type='gaussian', center=True).mean(std=20)
    return df.to_numpy()
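
# Note: the centered rolling windows leave NaNs in the first and last few
# rows of the smoothed output. Whether this matters downstream depends on
# how the TimeSeriesClass routines treat NaNs; a minimal, hypothetical
# helper (not used below) to drop NaN-containing rows would be:
#   def trim_nan_rows(x):
#       return x[~np.isnan(x).any(axis=1)]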


# The executor is created at module level without a __main__ guard, which
# assumes the fork start method (Linux); the data arrays are passed to the
# workers explicitly via executor.submit.
with concurrent.futures.ProcessPoolExecutor() as executor:
    futures = []
    for pair in args.pairs:
        relevant[pair] = {}
        # for cond in ["o", "g", "a"]:
        for cond in ["o"]:
            # Rows of participant <pair>S under condition `cond` whose
            # column-3 value exceeds 0.89; chained `&` applies all three
            # conditions element-wise.
            and_gate = ((data[:, 0] == pair + "S")
                        & (data[:, 1] == cond)
                        & (data[:, 3].astype(np.float64) > 0.89))
            index = np.where(and_gate)
            data_S = data[index, 3:].reshape(-1, 17).astype(np.float64)
            # Same selection for participant <pair>E
            and_gate = ((data[:, 0] == pair + "E")
                        & (data[:, 1] == cond)
                        & (data[:, 3].astype(np.float64) > 0.89))
            index = np.where(and_gate)
            data_E = data[index, 3:].reshape(-1, 17).astype(np.float64)
            data_S = standardize(data_S)
            data_E = standardize(data_E)
            data_S = smoothing(data_S)
            data_E = smoothing(data_E)
            # Truncate both series to their common length
            l = int(min(data_S.shape[0], data_E.shape[0]))
            data_S = data_S[:l, :]
            data_E = data_E[:l, :]
            relevant[pair][cond] = {}
            # for au in range(3, 17):
            #     x = data_E[:, au]
            #     y = data_S[:, au]
            #     relevant[pair][cond][au] = selectRelevantIntervals(
            #         X=x, Y=y, shifts=shifts, threshold=threshold,
            #         minLenInts=minLenInts, maxLenInts=maxLenInts,
            #         distancemeasure=distancemeasure)
            # Submit one job per expression
            for e_key, e_val in expr.items():
                a = data_S[:, e_val]
                b = data_E[:, e_val]
                # c = np.mean(data_S[:, e_val], axis=1)
                # d = np.mean(data_E[:, e_val], axis=1)
                future = executor.submit(gc, pair, cond, e_key, e_val, a, b,
                                         data_S, data_E, model_order_limit, SigLev)
                futures.append(future)
                # Bookkeeping of causal directions per expression (disabled;
                # these would update the SR/RS/bi/noc counters defined above):
                # if ScE == True and EcS == True:
                #     bi[e_key] += 1
                #     break
                # elif ScE == True:
                #     SR[e_key] += 1
                #     break
                # elif EcS == True:
                #     RS[e_key] += 1
                #     break
                # else:
                #     noc[e_key] += 1

    # Header row of the result table; one row per completed job follows.
    result = ["pair", "cond", "expression", "ScE", "EcS", "pval_ScE", "pval_EcS", "order",
              "r_ScE", "r_EcS", "r_pval_ScE", "r_pval_EcS", "r_order"]
    for future in concurrent.futures.as_completed(futures):
        try:
            result_row = future.result()  # re-raises any exception from the worker
            result = np.vstack((result, result_row))
        except Exception as e:
            print(e)
    # Store the causality results. Note that `pair` still holds the last
    # value of the loop variable, so one run writes a single file.
    np.save(f'/home/valapil/Project/ForkCausal_Adithya/results_causal/{pair}.npy', result)
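
# Reading the saved table back (sketch): np.vstack on mixed types upcasts
# every entry to a string, so the array loads as unicode without allow_pickle:
#   res = np.load('/home/valapil/Project/ForkCausal_Adithya/results_causal/04.npy')
#   print(res[0])   # header row
#   print(res[1:])  # one row per pair/cond/expression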