# glamina.py — relevant-interval detection between multivariate time series
  1. import numpy as np
  2. import pandas as pd
  3. import math
  4. from scipy.signal import medfilt
  5. import matplotlib.pyplot as plt
  6. def standardize(x):
  7. """
  8. standardize array, x-mu / sigma (mu: mean, sigma: standard deviation)
  9. :param x: pandas series
  10. :return: standardized array or series
  11. """
  12. mu = np.mean(x, axis=0)
  13. sigma = np.std(x, axis=0)
  14. return (x - mu) / sigma
  15. def smoothing(x):
  16. df = pd.DataFrame(x)
  17. df = df.rolling(window=5, center=True).median()
  18. df = df.rolling(window=5, win_type='gaussian', center=True).mean(std=20)
  19. return df.to_numpy()
  20. def intervalsToBinary(a, tslength):
  21. """ Transform list of intervals to an binary array of length tslength,
  22. where 1 is interval and 0 is no interval.
  23. E.g. self.idx=[(3,4)], tslength=10 returns [0,0,0,1,1,1,1,0,0,0]"""
  24. gt = np.zeros(tslength)
  25. if len(a) > 0:
  26. for i in a:
  27. gt[int(i[0]):int(i[0]) + int(i[1])] = 1
  28. return gt
  29. def binaryToIntervals(array):
  30. """ Transform binary array, where 1 is interval and 0 is no interval,
  31. to list of intervals.
  32. E.g. array=[0,0,0,1,1,1,1,0,0,0] returns [(3,4)]"""
  33. intervals = []
  34. start = np.where(array == 1)[0][0]
  35. current = np.where(array == 1)[0][0]
  36. count = 1
  37. for val in np.where(array == 1)[0][1:]:
  38. if val-1 == current and val != len(array)-1:
  39. current = val
  40. count += 1
  41. elif val-1 == current and val == len(array)-1:
  42. count += 1
  43. intervals += [start, count]
  44. else:
  45. intervals += [start, count]
  46. start = val
  47. current = val
  48. count = 1
  49. intervals = np.array(intervals).reshape(-1,2)
  50. return intervals
  51. def median_filter_intervals(a, kernel):
  52. tslength = int(np.sum(a[-1]))
  53. binary = intervalsToBinary(a, tslength)
  54. filteredintervals = medfilt(binary, kernel_size=kernel)
  55. filteredintervals = np.maximum.reduce([filteredintervals, binary])
  56. medfiltidx = binaryToIntervals(filteredintervals)
  57. lengthmedfilt = np.sum(np.array(medfiltidx)[:,1])
  58. countmedfilt = len(medfiltidx)
  59. return medfiltidx
  60. def arguments(X, Y, thresCorrelation, minLenInt,
  61. maxLenInt, shiftTS, distancemeasure):
  62. # argss1 and argss2 are used to get all shifts (no-shift, forward and backward) and all minimum length combinations.
  63. argss1 = [(X[shift:], Y[:-shift], thresCorrelation, minLen, maxLenInt, distancemeasure) if shift != 0
  64. else (X, Y, thresCorrelation, minLen, maxLenInt, distancemeasure)
  65. for shift in shiftTS for minLen in minLenInt]
  66. argss2 = [(X[:-shift], Y[shift:], thresCorrelation, minLen, maxLenInt, distancemeasure)
  67. for shift in shiftTS for minLen in minLenInt if shift != 0]
  68. args = argss1 + argss2
  69. return args
  70. def mutual_information_short(hgram):
  71. s = float(np.sum(hgram))
  72. hgram = hgram / s
  73. px = np.sum(hgram, axis=1)
  74. py = np.sum(hgram, axis=0)
  75. px_py = px[:, None] * py[None, :]
  76. nzs = hgram > 0
  77. return np.sum(hgram[nzs] * np.log(hgram[nzs] / px_py[nzs]))
  78. def ts_distance(distancemeasure,X,Y,bins):
  79. """
  80. Similarity measures that can be called in relevant interval selection
  81. approach. Either Peasron correlation or Normalized mutual information.
  82. New similarity measures can be added with other elif statements.
  83. :param distancemeasure: 'pearson' or 'mutual'
  84. :param X: Time series
  85. :param Y: Time series
  86. :param bins: Bins for mutual information computation
  87. :return: returns similarity of X and Y
  88. """
  89. import warnings
  90. if distancemeasure == 'pearson':
  91. with warnings.catch_warnings():
  92. warnings.simplefilter("ignore")
  93. r = np.corrcoef(X, Y)[0, 1]
  94. # k, _ = pearsonr(X, Y)
  95. elif distancemeasure == 'mutual':
  96. #compute best number of bins
  97. #breaks program sometimes...
  98. #k = 10
  99. #sigma = np.corrcoef(X, Y)[0][1]
  100. #if not np.isnan(sigma) or sigma == 1:
  101. # N = len(X)
  102. # k = int(np.ceil(0.5 + 0.5 * np.sqrt(
  103. # 1 + 4 * np.sqrt((6 * N * sigma * sigma) / (
  104. # 1 - (sigma * sigma))))))
  105. # Calculate histograms
  106. concat = np.hstack((X, Y))
  107. l, h = np.nanmin(concat), np.nanmax(concat)
  108. hgram = np.histogram2d(X, Y, range=[[l, h], [l, h]], bins=bins)[0]
  109. r = mutual_information_short(hgram)
  110. # Normalize mutual information
  111. r = np.sqrt(1 - np.exp(-2 * r))
  112. return r
def LongestSet(IntList):
    """
    Calculate longest set of maximal correlated intervals.

    Weighted interval scheduling by dynamic programming: intervals are
    sorted by start, a cumulative-best table is filled from the end, and
    the chosen chain is read back front-to-front.

    :param IntList: 2d list of intervals [start, duration], e.g.
        [[0,2], [1,2], [1,4], [4,3], [7,2], [8,2], [4,3]]
    :return: numpy array as list of intervals where the sum of interval
        durations is largest over all combinations of non-overlapping
        intervals; returns 0 (not an array) for empty input
    """
    if type(IntList) == np.ndarray:
        IntList = [(x[0], x[1]) for x in IntList]
    k = len(IntList)
    if k == 0:
        return 0
    # sort by start (ties broken by duration) so follower lookups work on starts
    IntList = np.array(sorted(IntList))
    NextInt = {}
    # For each interval i, precompute the indices of all intervals that start
    # at or after the earliest possible non-overlapping follower of i.
    for i, Interval in enumerate(IntList):
        # Interval[0] is the starting frame and Interval[1] is the duration
        end = Interval[0] + Interval[1]
        nextidx = IntList[IntList[:, 0] >= end]
        if nextidx.size > 0:
            minstart = np.min(nextidx[:, 0])
            NextInt[i] = np.argwhere(IntList[:, 0] >= minstart)
        else:
            # no interval starts after i ends
            NextInt[i] = None
    # Fill cumulative sums from the back of the sorted list.
    CumSum = np.zeros(k)
    SelInts = {}
    for idx, Interval in enumerate(reversed(IntList)):
        i = (k - 1) - idx
        b = NextInt[i]
        if np.all(b != None):
            # best follower maximizes (its cumulative sum + own duration)
            bestfollower = int(NextInt[i][np.argmax(
                CumSum[NextInt[i]] + Interval[1])])
            CumSum[i] = int(Interval[1] + CumSum[bestfollower])
            # NOTE(review): CumSum[i + 1] would be out of bounds at i == k-1,
            # but with positive durations the last-starting interval always
            # has NextInt None, so this branch is never reached there — confirm.
            if Interval[1] + CumSum[bestfollower] >= CumSum[i + 1]:
                SelInts[i] = bestfollower
            # NOTE(review): when the condition above is False, SelInts[i] is
            # left unset; the read-back loop below would raise KeyError if the
            # chain ever reaches such an i — verify this cannot happen.
        else:
            # no follower: chain ends here with this interval's duration
            CumSum[i] = Interval[1]
            SelInts[i] = None
    # Read the selected chain forward, starting from the latest index that
    # attains the maximal cumulative sum.
    Result = np.array([])
    current = np.where(CumSum == CumSum.max())[0][-1]
    while True:
        intval = IntList[current]
        Result = np.append(Result, [intval])
        current = SelInts[current]
        # None (chain end) is not a key of SelInts, so this also terminates
        if current not in SelInts:
            break
    Result = Result.reshape(-1, 2)
    return Result
  166. def glamina(X: np.ndarray, Y: np.ndarray, threshold: float, l_min: list[int], l_max: int, col_x: list[int] | None =None, col_y: list[int] | None =None, shifts: list[int] | None =None, distancemeasure: str ='pearson', bidirectional: bool = False):
  167. """
  168. Finds the relevant intervals among multiple time series based on correlation between them.
  169. :param X: whole data of subject 1 with columns holding each time series. By default, initial time series
  170. :param Y: whole data of subject 2 with columns holding each time series. By default, following (reaction) time series
  171. :param threshold: threshold value starting from which the correlation is significant
  172. :param l_min: list of minimum length of continuous relevant intervals to be considered
  173. :param l_max: maximum length of continuous relevant intervals
  174. :param col_x: list of selected columns in subject 1. By default, takes in all columns.
  175. :param col_y: list of selected columns in subject 2 with respect to col_x. By default, takes in all columns.
  176. :param shifts: list of values with which X is shifted back and forth in time.
  177. :param distancemeasure: correlation metric. By default, Pearson's correlation
  178. :param bidirectional: True if the influence of both subjects on each other are relevant. By default, influence of subject 1 on 2 is only considered
  179. :return: List of lists containing selected relevant intervals as [starting time, duration].
  180. """
  181. # error handling
  182. if (len(col_x) != len(col_y)):
  183. raise ValueError("Missing respective column indices in function call. There is a shape mismatch.")
  184. if shifts == None:
  185. shifts = [0]
  186. # to select the columns
  187. if (col_x != None and col_y != None):
  188. X = X[:, col_x]
  189. Y = Y[:, col_y]
  190. interval = []
  191. Result = []
  192. laminaresult = []
  193. x = []
  194. y = []
  195. # loop for each minimum interval length
  196. for m in l_min:
  197. if bidirectional:
  198. d1 = [[X[s:, :], Y[:-s, :]] if s != 0 else [X, Y] for s in shifts]
  199. d2 = [[X[:-s, :], Y[s:, :]] for s in shifts if s != 0]
  200. d = d1 + d2
  201. elif not bidirectional:
  202. d = [[X[:-s, :], Y[s:, :]] for s in shifts if s != 0]
  203. for X, Y in d:
  204. indexlength = min(X.shape[0], Y.shape[0])
  205. IntMat = np.zeros([indexlength, l_max - m + 1], dtype=np.float64)
  206. # calculate correlation for minimum interval length
  207. winlen = m
  208. n_channel = X.shape[1]
  209. # j = 0
  210. # loop to average correlation for minimum interval length
  211. for col in range(0, n_channel):
  212. x = X[:, col]
  213. y = Y[:, col]
  214. for i in range(indexlength - m + 1):
  215. r = ts_distance(distancemeasure, x[i:i + winlen],
  216. y[i:i + winlen], bins=10)
  217. if not math.isnan(r):
  218. IntMat[i, 0] += r
  219. # j +=1
  220. IntMat = IntMat/n_channel
  221. # loop to average correlation for intervals above minimum interval
  222. for col in range(0, n_channel):
  223. for winlen in range(m + 1, l_max + 1):
  224. for i in range(indexlength - winlen + 1):
  225. if IntMat[i, winlen - 1 - m] >= threshold and \
  226. IntMat[i + 1, winlen - 1 - m] >= threshold:
  227. r = ts_distance(distancemeasure, x[i:i + winlen],
  228. y[i:i + winlen], bins=10)
  229. if not math.isnan(r):
  230. IntMat[i, winlen - m] += r
  231. IntMat = IntMat[:, 1:] / n_channel
  232. CorInts = np.where(IntMat >= threshold)
  233. del IntMat
  234. CorInts = list(zip(CorInts[0], CorInts[1] + m))
  235. # check if correlated intervals are maximal
  236. ResultInt = []
  237. for i, lenWin in CorInts:
  238. if ((i - 1, lenWin + 1) not in CorInts) and ((i, lenWin + 1) not in CorInts):
  239. ResultInt += [(i, lenWin)]
  240. del CorInts
  241. interval.append(ResultInt)
  242. if len(interval) > 0:
  243. Result = [int for intlist in interval for int in intlist] # flatten the array
  244. if len(Result) > 0:
  245. laminaresult = LongestSet(Result) # find the longest set of intervals from Result
  246. return laminaresult
  247. def gen_series(lenTS=1000, start=50, end=300, amplitude=50, noise=1, seed=10):
  248. # step_signal = np.zeros(lenTS)
  249. # # step_signal[start:end+1] = amplitude
  250. # step_signal[start:end+1] = np.linspace(5, amplitude, end-start+1)
  251. step_signal = np.linspace(5, amplitude, lenTS)
  252. step_signal[:start] = step_signal[end+1:] = 0
  253. if noise !=0:
  254. np.random.seed(seed)
  255. noise = noise * np.random.randn(lenTS)
  256. s1 = step_signal + noise
  257. return s1
if __name__ == "__main__":
    # Demo: run glamina on two pairs of synthetic ramp signals, median-filter
    # the detected intervals, and plot the detections over the first series.
    # Earlier experiment on preprocessed fNIRS CSV data, kept for reference:
    # np.random.seed(10)
    # shifts = [0]
    # threshold = 0.99
    # minLenInts = [70]
    # maxLenInts = 200
    # distancemeasure = 'pearson'
    #
    # medfiltlength = 51
    #
    # path = '/home/datasets4/fNIRS/fNIRSData/MI-003/Sub1/Sub1_preprocessed.csv'
    # df = pd.read_csv(path)
    # data_1 = df.to_numpy()
    #
    # path = '/home/datasets4/fNIRS/fNIRSData/MI-003/Sub2/Sub2_preprocessed.csv'
    # df = pd.read_csv(path)
    # data_2 = df.to_numpy()
    #
    # data1 = standardize(data_1)
    # data2 = standardize(data_2)
    #
    # data1 = smoothing(data1)
    # data2 = smoothing(data2)

    # parameters for the synthetic demo
    shifts = [0]  # shifts that must be considered
    threshold = 0.87  # (imp) depends on the correlation values
    minLenInts = [50]  # (imp) depends on the minimum continuous interval that must be detected
    maxLenInts = 250  # (less relevant) depends on the max continuous interval
    distancemeasure = 'pearson'
    medfiltlength = 51
    # two "stimulus" series (S*) and two "reaction" series (R*) with
    # overlapping active windows so a correlated interval exists
    S1 = gen_series(start=100, end=300, amplitude=20, noise=1, seed=10)
    R1 = gen_series(start=120, end=320, amplitude=20, noise=1, seed=11)
    S2 = gen_series(start=90, end=280, amplitude=20, noise=1, seed=12)
    R2 = gen_series(start=200, end=260, amplitude=20, noise=1, seed=13)
    data1 = np.column_stack((S1, S2))
    data2 = np.column_stack((R1, R2))
    # detect relevant intervals, then close small gaps with a median filter
    result = glamina(data1, data2, threshold=threshold, l_min=minLenInts, l_max=maxLenInts, col_x=[0,1], col_y=[0,1], shifts=shifts, distancemeasure=distancemeasure)
    Result = median_filter_intervals(result, kernel=medfiltlength)
    # Earlier multi-panel plotting of the synthetic series, kept for reference:
    # import matplotlib.pyplot as plt
    # plt.subplot(5, 1, 1)
    # plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k')
    # plt.legend(loc="upper left")
    # plt.subplot(5, 1, 2)
    # plt.plot(np.arange(0, len(S2)), np.array(S2), label='x2', color='b')
    # plt.legend(loc="upper left")
    # plt.subplot(5, 1, 3)
    # plt.plot(np.arange(0, len(R1)), np.array(R1), label='y1', color='k')
    # plt.legend(loc="upper left")
    # plt.subplot(5, 1, 4)
    # plt.plot(np.arange(0, len(R2)), np.array(R2), label='y2', color='b')
    # plt.legend(loc="upper left")
    # plt.subplot(5, 1, 5)
    # plt.plot(np.arange(0, len(R2)), np.array(gen_series(start=200, end=260, amplitude=30, noise=0)),
    #          label='ground truth', color='g')
    # if len(Result) > 0:
    #     for idx, i in enumerate(Result):
    #         plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
    #                     label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
    # plt.legend(loc="upper left")
    # plt.show()
    # Earlier 7-panel plotting of real fNIRS channels, kept for reference:
    # S1 = data1[:, 6]
    # S2 = data1[:, 7]
    # S3 = data1[:, 8]
    # R1 = data2[:, 6]
    # R2 = data2[:, 7]
    # R3 = data2[:, 8]
    #
    # S4 = data1[:, 12]
    # R4 = data2[:, 12]
    #
    # plt.subplot(7, 1, 1)
    # plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k')  # color = bgrcmykw
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 2)
    # plt.plot(np.arange(0, len(S2)), np.array(S2), label='x2', color='b')
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 3)
    # plt.plot(np.arange(0, len(S3)), np.array(S3), label='x3', color='m')
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 4)
    # plt.plot(np.arange(0, len(R1)), np.array(R1), label='y1', color='k')
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 5)
    # plt.plot(np.arange(0, len(R2)), np.array(R2), label='y2', color='b')
    # plt.legend(loc="upper left")
    # plt.subplot(7, 1, 6)
    # plt.plot(np.arange(0, len(R3)), np.array(R3), label='y3', color='m')
    # plt.legend(loc="upper left")
    #
    # path = '/home/valapil/Project/fnirs_file/annotations/MI-003_clean.csv'
    # df = pd.read_csv(path)
    # truth = df.to_numpy()
    # ind = np.where(truth[:, 1] != 10)[0]
    # sig = np.zeros(data1.shape[0])
    # a = truth[ind, :].astype(np.int16)
    # for s, e in a[:, [0, 2]]:
    #     sig[s:e + 1] = 1
    # plt.subplot(7, 1, 7)
    # plt.plot(np.arange(0, len(sig)), sig, label='ground truth', color='g')
    #
    # if len(Result) > 0:
    #     for idx, i in enumerate(Result):
    #         plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
    #                     label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
    # plt.legend(loc="upper left")
    # plot the first stimulus series with the detected intervals shaded in red
    plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k')
    if len(Result) > 0:
        for idx, i in enumerate(Result):
            # shade each detected interval; label only the first span
            plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
                        label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
    plt.legend(loc="upper left")
    print(Result)
    plt.show()