glamina.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. import numpy as np
  2. import pandas as pd
  3. import math
  4. from scipy.signal import medfilt
  5. import matplotlib.pyplot as plt
  6. def standardize(x):
  7. """
  8. standardize array, x-mu / sigma (mu: mean, sigma: standard deviation)
  9. :param x: pandas series
  10. :return: standardized array or series
  11. """
  12. mu = np.mean(x, axis=0)
  13. sigma = np.std(x, axis=0)
  14. return (x - mu) / sigma
  15. def smoothing(x):
  16. df = pd.DataFrame(x)
  17. df = df.rolling(window=5, center=True).median()
  18. df = df.rolling(window=5, win_type='gaussian', center=True).mean(std=20)
  19. return df.to_numpy()
  20. def intervalsToBinary(a, tslength):
  21. """ Transform list of intervals to an binary array of length tslength,
  22. where 1 is interval and 0 is no interval.
  23. E.g. self.idx=[(3,4)], tslength=10 returns [0,0,0,1,1,1,1,0,0,0]"""
  24. gt = np.zeros(tslength)
  25. if len(a) > 0:
  26. for i in a:
  27. gt[int(i[0]):int(i[0]) + int(i[1])] = 1
  28. return gt
  29. def binaryToIntervals(array):
  30. """ Transform binary array, where 1 is interval and 0 is no interval,
  31. to list of intervals.
  32. E.g. array=[0,0,0,1,1,1,1,0,0,0] returns [(3,4)]"""
  33. intervals = []
  34. start = np.where(array == 1)[0][0]
  35. current = np.where(array == 1)[0][0]
  36. count = 1
  37. for val in np.where(array == 1)[0][1:]:
  38. if val-1 == current and val != len(array)-1:
  39. current = val
  40. count += 1
  41. elif val-1 == current and val == len(array)-1:
  42. count += 1
  43. intervals += [start, count]
  44. else:
  45. intervals += [start, count]
  46. start = val
  47. current = val
  48. count = 1
  49. intervals = np.array(intervals).reshape(-1,2)
  50. return intervals
  51. def median_filter_intervals(a, kernel):
  52. tslength = int(np.sum(a[-1]))
  53. binary = intervalsToBinary(a, tslength)
  54. filteredintervals = medfilt(binary, kernel_size=kernel)
  55. filteredintervals = np.maximum.reduce([filteredintervals, binary])
  56. medfiltidx = binaryToIntervals(filteredintervals)
  57. lengthmedfilt = np.sum(np.array(medfiltidx)[:,1])
  58. countmedfilt = len(medfiltidx)
  59. return medfiltidx
  60. def arguments(X, Y, thresCorrelation, minLenInt,
  61. maxLenInt, shiftTS, distancemeasure):
  62. # argss1 and argss2 are used to get all shifts (no-shift, forward and backward) and all minimum length combinations.
  63. argss1 = [(X[shift:], Y[:-shift], thresCorrelation, minLen, maxLenInt, distancemeasure) if shift != 0
  64. else (X, Y, thresCorrelation, minLen, maxLenInt, distancemeasure)
  65. for shift in shiftTS for minLen in minLenInt]
  66. argss2 = [(X[:-shift], Y[shift:], thresCorrelation, minLen, maxLenInt, distancemeasure)
  67. for shift in shiftTS for minLen in minLenInt if shift != 0]
  68. args = argss1 + argss2
  69. return args
  70. def mutual_information_short(hgram):
  71. s = float(np.sum(hgram))
  72. hgram = hgram / s
  73. px = np.sum(hgram, axis=1)
  74. py = np.sum(hgram, axis=0)
  75. px_py = px[:, None] * py[None, :]
  76. nzs = hgram > 0
  77. return np.sum(hgram[nzs] * np.log(hgram[nzs] / px_py[nzs]))
  78. def ts_distance(distancemeasure,X,Y,bins):
  79. """
  80. Similarity measures that can be called in relevant interval selection
  81. approach. Either Peasron correlation or Normalized mutual information.
  82. New similarity measures can be added with other elif statements.
  83. :param distancemeasure: 'pearson' or 'mutual'
  84. :param X: Time series
  85. :param Y: Time series
  86. :param bins: Bins for mutual information computation
  87. :return: returns similarity of X and Y
  88. """
  89. import warnings
  90. if distancemeasure == 'pearson':
  91. with warnings.catch_warnings():
  92. warnings.simplefilter("ignore")
  93. r = np.corrcoef(X, Y)[0, 1]
  94. # k, _ = pearsonr(X, Y)
  95. elif distancemeasure == 'mutual':
  96. #compute best number of bins
  97. #breaks program sometimes...
  98. #k = 10
  99. #sigma = np.corrcoef(X, Y)[0][1]
  100. #if not np.isnan(sigma) or sigma == 1:
  101. # N = len(X)
  102. # k = int(np.ceil(0.5 + 0.5 * np.sqrt(
  103. # 1 + 4 * np.sqrt((6 * N * sigma * sigma) / (
  104. # 1 - (sigma * sigma))))))
  105. # Calculate histograms
  106. concat = np.hstack((X, Y))
  107. l, h = np.nanmin(concat), np.nanmax(concat)
  108. hgram = np.histogram2d(X, Y, range=[[l, h], [l, h]], bins=bins)[0]
  109. r = mutual_information_short(hgram)
  110. # Normalize mutual information
  111. r = np.sqrt(1 - np.exp(-2 * r))
  112. return r
def LongestSet(IntList):
    """
    Calculate the longest set of maximal correlated intervals.

    From possibly overlapping [start, duration] intervals, select a chain of
    non-overlapping intervals whose summed duration is largest. This is solved
    by dynamic programming over the intervals sorted by start.
    :param IntList: 2d list (or numpy array) of [start, duration] intervals,
        e.g. [[0,2], [1,2], [1,4], [4,3], [7,2], [8,2], [4,3]]
    :return: float numpy array of shape (n, 2) with the selected intervals
        (in increasing start order for positive durations). Their durations
        sum to the largest value over all combinations of non-overlapping
        intervals. Returns the int 0 (not an empty array) when IntList is
        empty.
    """
    # sorted() cannot compare numpy rows directly, so convert them to tuples
    # (then sorting orders by start, ties by duration).
    if type(IntList) == np.ndarray:
        IntList = [(x[0], x[1]) for x in IntList]
    k = len(IntList)
    if k == 0:
        return 0
    IntList = np.array(sorted(IntList))
    NextInt = {}
    # For each interval i, NextInt[i] holds the indices (an (n, 1) array from
    # np.argwhere) of every interval starting at or after i's end, or None if
    # there is none. Because IntList is sorted, these indices form a suffix.
    for i, Interval in enumerate(IntList):
        end = Interval[0] + Interval[1] # Interval[0] is the starting frame and Interval[1] is the duration
        nextidx = IntList[IntList[:, 0] >= end]
        if nextidx.size > 0:
            minstart = np.min(nextidx[:, 0])
            NextInt[i] = np.argwhere(IntList[:, 0] >= minstart)
        else:
            NextInt[i] = None
    # CumSum[i]: largest total duration of a non-overlapping chain that begins
    # with interval i. SelInts[i]: index of the interval following i in that
    # chain (None = chain ends at i).
    # Visit intervals from last to first so every follower's CumSum is final.
    CumSum = np.zeros(k)
    SelInts = {}
    for idx, Interval in enumerate(reversed(IntList)):
        i = (k - 1) - idx
        b = NextInt[i]
        if np.all(b != None):
            # np.argmax returns the first maximum, so ties go to the
            # earliest-starting follower.
            bestfollower = int(NextInt[i][np.argmax(
                CumSum[NextInt[i]] + Interval[1])])
            # NOTE: int() truncates non-integer durations.
            CumSum[i] = int(Interval[1] + CumSum[bestfollower])
            # NOTE(review): SelInts[i] stays unset when starting at i + 1
            # gives a larger total. The forward walk below appears never to
            # reach such an i, because followers are picked by argmax over a
            # suffix that includes i + 1. If the last interval had a
            # non-positive duration it would have followers, and
            # CumSum[i + 1] would index past the end.
            if Interval[1] + CumSum[bestfollower] >= CumSum[i + 1]:
                SelInts[i] = bestfollower
        else:
            CumSum[i] = Interval[1]
            SelInts[i] = None
    # Walk forward from the chain with the largest total ([-1] takes the last
    # index on ties) and follow SelInts links.
    Result = np.array([])
    current = np.where(CumSum == CumSum.max())[0][-1]
    while True:
        intval = IntList[current]
        Result = np.append(Result, [intval])
        current = SelInts[current]
        # The end marker None is never a key of SelInts, which stops the walk.
        if current not in SelInts:
            break
    Result = Result.reshape(-1, 2)
    return Result
  166. def glamina(X: np.ndarray, Y: np.ndarray, threshold: float, l_min: list[int], l_max: int, col_x: list[int] | None =None, col_y: list[int] | None =None, shifts: list[int] | None =None, distancemeasure: str ='pearson'):
  167. """
  168. Finds the relevant intervals among multiple time series based on correlation between them.
  169. :param X: whole data of subject 1 with columns holding each time series
  170. :param Y: whole data of subject 2 with columns holding each time series
  171. :param threshold: threshold value starting from which the correlation is significant
  172. :param l_min: list of minimum length of continuous relevant intervals to be considered
  173. :param l_max: maximum length of continuous relevant intervals
  174. :param col_x: list of selected columns in subject 1. By default, takes in all columns.
  175. :param col_y: list of selected columns in subject 2 with respect to col_x. By default, takes in all columns.
  176. :param shifts: list of values with which X is shifted back and forth in time.
  177. :param distancemeasure: correlation metric. By default, Pearson's correlation
  178. :return: List of lists containing selected relevant intervals as [starting time, duration].
  179. """
  180. # error handling
  181. if (len(col_x) != len(col_y)):
  182. raise ValueError("Missing respective column indices in function call. There is a shape mismatch.")
  183. if shifts == None:
  184. shifts = [0]
  185. # to select the columns
  186. if (col_x != None and col_y != None):
  187. X = X[:, col_x]
  188. Y = Y[:, col_y]
  189. interval = []
  190. Result = []
  191. laminaresult = []
  192. x = []
  193. y = []
  194. # loop for each minimum interval length
  195. for m in l_min:
  196. d1 = [[X[s:, :], Y[:-s, :]] if s != 0 else [X, Y] for s in shifts]
  197. d2 = [[X[:-s, :], Y[s:, :]] for s in shifts if s != 0]
  198. d = d1 + d2
  199. for X, Y in d:
  200. indexlength = min(X.shape[0], Y.shape[0])
  201. IntMat = np.zeros([indexlength, l_max - m + 1], dtype=np.float64)
  202. # calculate correlation for minimum interval length
  203. winlen = m
  204. n_channel = X.shape[1]
  205. # j = 0
  206. # loop to average correlation for minimum interval length
  207. for col in range(0, n_channel):
  208. x = X[:, col]
  209. y = Y[:, col]
  210. for i in range(indexlength - m + 1):
  211. r = ts_distance(distancemeasure, x[i:i + winlen],
  212. y[i:i + winlen], bins=10)
  213. if not math.isnan(r):
  214. IntMat[i, 0] += r
  215. # j +=1
  216. IntMat = IntMat/n_channel
  217. # loop to average correlation for intervals above minimum interval
  218. for col in range(0, n_channel):
  219. for winlen in range(m + 1, l_max + 1):
  220. for i in range(indexlength - winlen + 1):
  221. if IntMat[i, winlen - 1 - m] >= threshold and \
  222. IntMat[i + 1, winlen - 1 - m] >= threshold:
  223. r = ts_distance(distancemeasure, x[i:i + winlen],
  224. y[i:i + winlen], bins=10)
  225. if not math.isnan(r):
  226. IntMat[i, winlen - m] += r
  227. IntMat = IntMat[:, 1:] / n_channel
  228. CorInts = np.where(IntMat >= threshold)
  229. del IntMat
  230. CorInts = list(zip(CorInts[0], CorInts[1] + m))
  231. # check if correlated intervals are maximal
  232. ResultInt = []
  233. for i, lenWin in CorInts:
  234. if ((i - 1, lenWin + 1) not in CorInts) and ((i, lenWin + 1) not in CorInts):
  235. ResultInt += [(i, lenWin)]
  236. del CorInts
  237. interval.append(ResultInt)
  238. if len(interval) > 0:
  239. Result = [int for intlist in interval for int in intlist] # flatten the array
  240. if len(Result) > 0:
  241. laminaresult = LongestSet(Result) # find the longest set of intervals from Result
  242. return laminaresult
  243. def gen_series(lenTS=1000, start=50, end=300, amplitude=50, noise=1, seed=10):
  244. # step_signal = np.zeros(lenTS)
  245. # # step_signal[start:end+1] = amplitude
  246. # step_signal[start:end+1] = np.linspace(5, amplitude, end-start+1)
  247. step_signal = np.linspace(5, amplitude, lenTS)
  248. step_signal[:start] = step_signal[end+1:] = 0
  249. if noise !=0:
  250. np.random.seed(seed)
  251. noise = noise * np.random.randn(lenTS)
  252. s1 = step_signal + noise
  253. return s1
  254. if __name__ == "__main__":
  255. # np.random.seed(10)
  256. # shifts = [0]
  257. # threshold = 0.99
  258. # minLenInts = [70]
  259. # maxLenInts = 200
  260. # distancemeasure = 'pearson'
  261. #
  262. # medfiltlength = 51
  263. #
  264. # path = '/home/datasets4/fNIRS/fNIRSData/MI-003/Sub1/Sub1_preprocessed.csv'
  265. # df = pd.read_csv(path)
  266. # data_1 = df.to_numpy()
  267. #
  268. # path = '/home/datasets4/fNIRS/fNIRSData/MI-003/Sub2/Sub2_preprocessed.csv'
  269. # df = pd.read_csv(path)
  270. # data_2 = df.to_numpy()
  271. #
  272. # data1 = standardize(data_1)
  273. # data2 = standardize(data_2)
  274. #
  275. # data1 = smoothing(data1)
  276. # data2 = smoothing(data2)
  277. shifts = [0] # shifts that must be considered
  278. threshold = 0.87 # (imp) depends on the correlation values
  279. minLenInts = [50] # (imp) depends on the minimum continuous interval that must be detected
  280. maxLenInts = 250 # (less relevant) depends on the max continuous interval
  281. distancemeasure = 'pearson'
  282. medfiltlength = 51
  283. S1 = gen_series(start=100, end=300, amplitude=20, noise=1, seed=10)
  284. R1 = gen_series(start=120, end=320, amplitude=20, noise=1, seed=11)
  285. S2 = gen_series(start=90, end=280, amplitude=20, noise=1, seed=12)
  286. R2 = gen_series(start=200, end=260, amplitude=20, noise=1, seed=13)
  287. data1 = np.column_stack((S1, S2))
  288. data2 = np.column_stack((R1, R2))
  289. result = glamina(data1, data2, threshold=threshold, l_min=minLenInts, l_max=maxLenInts, col_x=[0,1], col_y=[0,1], shifts=shifts, distancemeasure=distancemeasure)
  290. Result = median_filter_intervals(result, kernel=medfiltlength)
  291. # import matplotlib.pyplot as plt
  292. # plt.subplot(5, 1, 1)
  293. # plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k')
  294. # plt.legend(loc="upper left")
  295. # plt.subplot(5, 1, 2)
  296. # plt.plot(np.arange(0, len(S2)), np.array(S2), label='x2', color='b')
  297. # plt.legend(loc="upper left")
  298. # plt.subplot(5, 1, 3)
  299. # plt.plot(np.arange(0, len(R1)), np.array(R1), label='y1', color='k')
  300. # plt.legend(loc="upper left")
  301. # plt.subplot(5, 1, 4)
  302. # plt.plot(np.arange(0, len(R2)), np.array(R2), label='y2', color='b')
  303. # plt.legend(loc="upper left")
  304. # plt.subplot(5, 1, 5)
  305. # plt.plot(np.arange(0, len(R2)), np.array(gen_series(start=200, end=260, amplitude=30, noise=0)),
  306. # label='ground truth', color='g')
  307. # if len(Result) > 0:
  308. # for idx, i in enumerate(Result):
  309. # plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
  310. # label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
  311. # plt.legend(loc="upper left")
  312. # plt.show()
  313. # S1 = data1[:, 6]
  314. # S2 = data1[:, 7]
  315. # S3 = data1[:, 8]
  316. # R1 = data2[:, 6]
  317. # R2 = data2[:, 7]
  318. # R3 = data2[:, 8]
  319. #
  320. # S4 = data1[:, 12]
  321. # R4 = data2[:, 12]
  322. #
  323. # plt.subplot(7, 1, 1)
  324. # plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k') # color = bgrcmykw
  325. # plt.legend(loc="upper left")
  326. # plt.subplot(7, 1, 2)
  327. # plt.plot(np.arange(0, len(S2)), np.array(S2), label='x2', color='b')
  328. # plt.legend(loc="upper left")
  329. # plt.subplot(7, 1, 3)
  330. # plt.plot(np.arange(0, len(S3)), np.array(S3), label='x3', color='m')
  331. # plt.legend(loc="upper left")
  332. # plt.subplot(7, 1, 4)
  333. # plt.plot(np.arange(0, len(R1)), np.array(R1), label='y1', color='k')
  334. # plt.legend(loc="upper left")
  335. # plt.subplot(7, 1, 5)
  336. # plt.plot(np.arange(0, len(R2)), np.array(R2), label='y2', color='b')
  337. # plt.legend(loc="upper left")
  338. # plt.subplot(7, 1, 6)
  339. # plt.plot(np.arange(0, len(R3)), np.array(R3), label='y3', color='m')
  340. # plt.legend(loc="upper left")
  341. #
  342. # path = '/home/valapil/Project/fnirs_file/annotations/MI-003_clean.csv'
  343. # df = pd.read_csv(path)
  344. # truth = df.to_numpy()
  345. # ind = np.where(truth[:, 1] != 10)[0]
  346. # sig = np.zeros(data1.shape[0])
  347. # a = truth[ind, :].astype(np.int16)
  348. # for s, e in a[:, [0, 2]]:
  349. # sig[s:e + 1] = 1
  350. # plt.subplot(7, 1, 7)
  351. # plt.plot(np.arange(0, len(sig)), sig, label='ground truth', color='g')
  352. #
  353. # if len(Result) > 0:
  354. # for idx, i in enumerate(Result):
  355. # plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
  356. # label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
  357. # plt.legend(loc="upper left")
  358. plt.plot(np.arange(0, len(S1)), np.array(S1), label='x1', color='k')
  359. if len(Result) > 0:
  360. for idx, i in enumerate(Result):
  361. plt.axvspan(i[0], (i[0] + i[1] - 1), facecolor='r', alpha=0.5,
  362. label=f'detection\nth:{threshold}\nmin:{minLenInts}' if idx == 0 else "")
  363. plt.legend(loc="upper left")
  364. print(Result)
  365. plt.show()