eye_state_prototype.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import numpy as np
  2. import pandas as pd
  3. import stumpy
  4. from scipy import signal
  5. import matplotlib.pyplot as plt
  6. ##########Prototyping##########
  7. # prototype extraction
  8. def motif_extraction(ear_ts: np.ndarray, m=100, max_matches=10):
  9. """Extract top motifs in EAR time series."""
  10. mp = stumpy.stump(ear_ts, m) # matrix profile
  11. motif_distances, motif_indices = stumpy.motifs(ear_ts, mp[:, 0], max_matches=max_matches)
  12. return motif_distances, motif_indices
  13. def learn_prototypes(ear_ts: np.ndarray, m=100, max_matches=10):
  14. """Return top motifs."""
  15. _, motif_indices = motif_extraction(ear_ts, m, max_matches)
  16. motifs = np.array([ear_ts[idx:idx+m] for idx in (motif_indices[0] + m // 2)])
  17. return motifs
  18. # manual definition
  19. def combined_gaussian(sig1: float, sig2: float, avg: float, prom: float, m=100, mu=40, noise=None):
  20. """Manual prototype composed of two Gaussians."""
  21. y1 = - prom * signal.gaussian(2*m, std=sig1) + avg
  22. y2 = - prom * signal.gaussian(2*m, std=sig2) + avg
  23. y = np.append(y1[:m], y2[m:])
  24. if noise is not None:
  25. y = y + noise
  26. return y[m-mu:2*m-mu]
  27. def nosie(noise_std: float, m=100):
  28. "Random noise based on learned data."
  29. np.random.seed(0)
  30. noise = (np.random.random(2*100) * 2 - 1) * noise_std
  31. return noise
  32. ##########Detection##########
  33. # blink pattern matching
  34. def fpm(Q: np.ndarray, T: np.ndarray, th=3.0):
  35. """Fast Pattern Matching"""
  36. def threshold(D):
  37. return np.nanmax([np.nanmean(D) - th * np.std(D), np.nanmin(D)])
  38. matches = stumpy.match(Q, T, max_distance=threshold)
  39. # match_indices = matches[:, 1]
  40. return matches
  41. def index_matching(indices1: np.ndarray, indices2: np.ndarray, max_distance=50):
  42. "Match indices saved in two arrays."
  43. matched_pairs = []
  44. no_match = []
  45. for idx1 in indices1:
  46. dists = np.abs(indices2 - idx1)
  47. min_dist = np.min(dists)
  48. if min_dist < max_distance:
  49. matched_pairs.append([idx1, indices2[np.argmin(dists)]]) # when there are two equal-dist matches, always keep the first onr
  50. else:
  51. no_match.append(idx1)
  52. return np.array(matched_pairs), np.array(no_match)
  53. # find peaks
  54. def find_peaks_in_ear_ts(ear_ts: np.ndarray, h_th=0.15, p_th=None, t_th=None, d_th=50):
  55. """
  56. Find peaks in EAR time series.
  57. h_th = 0.15 # height threshold
  58. p_th = None # prominence threshold
  59. t_th = None # threshold
  60. d_th = 50 # distance threshold
  61. """
  62. peaks, properties = signal.find_peaks(-ear_ts, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
  63. heights = - properties["peak_heights"]
  64. return peaks, heights
  65. def cal_bpm_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_th=0.15, p_th=None, t_th=None, d_th=50, save_path=None):
  66. """Caculate and save find peaks and fast pattern matching results."""
  67. # fast pattern matching
  68. m = len(prototype)
  69. matches_r = fpm(Q=prototype, T=ear_r, th=fpm_th)
  70. matches_l = fpm(Q=prototype, T=ear_l, th=fpm_th)
  71. match_indices_r = matches_r[:, 1]
  72. match_indices_l = matches_l[:, 1]
  73. # index matching
  74. sorted_indices_r = np.sort(match_indices_r)
  75. sorted_indices_l = np.sort(match_indices_l)
  76. matched_pairs, no_match = index_matching(sorted_indices_r, sorted_indices_l, max_distance=50)
  77. # find peaks
  78. peaks_r, heights_r = find_peaks_in_ear_ts(ear_r, h_th, t_th, p_th, d_th)
  79. peaks_l, heights_l = find_peaks_in_ear_ts(ear_l, h_th, t_th, p_th, d_th)
  80. # save results
  81. results = {}
  82. results["match_indices_r"] = matches_r[:, 1]
  83. results["match_values_r"] = matches_r[:, 0]
  84. results["match_indices_l"] = matches_l[:, 1]
  85. results["match_values_l"] = matches_l[:, 0]
  86. results["sorted_indices_r"] = sorted_indices_r
  87. results["sorted_indices_l"] = sorted_indices_l
  88. results["matched_pairs_r"] = matched_pairs[:, 0]
  89. results["matched_pairs_l"] = matched_pairs[:, 1]
  90. results["peaks_r"] = peaks_r
  91. results["heights_r"] = heights_r
  92. results["peaks_l"] = peaks_l
  93. results["heights_l"] = heights_l
  94. if save_path is not None:
  95. results_df = pd.DataFrame({key:pd.Series(value) for key, value in results.items()})
  96. results_df.to_csv(save_path)
  97. return results
  98. ##########Analysis##########
  99. def get_apex(T: np.ndarray, m: int, match_indices: np.ndarray):
  100. """Estimated apex in each extracted matches."""
  101. apex_indices = []
  102. apex_proms = []
  103. for idx in match_indices:
  104. apex_prom = np.max(T[idx:idx+m]) - np.min(T[idx:idx+m])
  105. apex_idx = idx + np.argmin(T[idx:idx+m])
  106. apex_indices.append(apex_idx)
  107. apex_proms.append(apex_prom)
  108. return np.array(apex_indices), np.array(apex_proms)
  109. def get_stats(diff: np.ndarray) -> dict:
  110. """Get statistics (avg, std, median) and save them in a dict."""
  111. diff_stats = dict()
  112. diff_stats["avg"] = np.mean(diff)
  113. diff_stats["std"] = np.std(diff)
  114. diff_stats["median"] = np.median(diff)
  115. return diff_stats
  116. # other utilities
  117. def smooth(data: np.ndarray, window_len=5, window="flat"):
  118. "Function for smoothing the data. For now, window type: the moving average (flat)."
  119. if data.ndim != 1:
  120. raise ValueError("Only accept 1D array as input.")
  121. if data.size < window_len:
  122. raise ValueError("The input data should be larger than the window size.")
  123. if window == "flat":
  124. kernel = np.ones(window_len) / window_len
  125. s = np.convolve(data, kernel, mode="same")
  126. return s
##########Visualization##########
  128. def plot_ear(ear_r: np.ndarray, ear_l: np.ndarray, xmin=0, xmax=20000):
  129. """Plot right and left ear score."""
  130. fig, axs = plt.subplots(2, figsize=(20, 6), sharex=True, sharey=True, gridspec_kw={'hspace': 0})
  131. axs[0].plot(ear_r, c='r', label='right eye')
  132. axs[0].minorticks_on()
  133. if len(ear_r) < xmax:
  134. xmax = len(ear_r)
  135. axs[0].set_xlim([xmin, xmax])
  136. axs[0].set_title("EAR Time Series", fontsize="30")
  137. axs[0].set_ylabel('right', fontsize="18")
  138. axs[1].plot(ear_l, c='b', label='left eye')
  139. axs[1].set_ylabel('left', fontsize="18")
  140. axs[1].set_xlabel('Frame', fontsize="18")
  141. plt.show()
  142. return True
  143. def plot_mp(ts: np.ndarray, mp: np.ndarray):
  144. """Plot EAR and Matrix Profile."""
  145. fig, axs = plt.subplots(2, figsize=(10, 6), sharex=True, gridspec_kw={'hspace': 0})
  146. plt.suptitle('EAR Score and Matrix Profile', fontsize='30')
  147. axs[0].plot(ts)
  148. axs[0].set_ylabel('EAR', fontsize='20')
  149. axs[0].set(xlim=[0, len(ts)], ylim=[0, 0.4])
  150. axs[0].minorticks_on()
  151. axs[1].set_xlabel('Frame', fontsize ='20')
  152. axs[1].set_ylabel('Matrix Profile', fontsize='20')
  153. axs[1].plot(mp[:, 0])
  154. plt.show()
  155. return True
  156. def plot_zoom_in(ax, ear, sorted_indices, peaks, heights, subregion, zoom_in_box, m=100):
  157. """Zoom in subregion on the orginial plot."""
  158. x1, x2, y1, y2 = subregion
  159. x_in, y_in, w_in, h_in = zoom_in_box
  160. axin = ax.inset_axes([x_in, y_in, w_in, h_in],
  161. xlim=(x1, x2), ylim=(y1, y2),
  162. xticklabels=[], yticklabels=[])
  163. for i, match_idx in enumerate(sorted_indices):
  164. axin.axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey")
  165. axin.plot(ear, c="r", zorder=1)
  166. axin.scatter(peaks, heights, marker='x', zorder=2)
  167. axin.set_xticks([])
  168. axin.set_yticks([])
  169. ax.indicate_inset_zoom(axin, edgecolor="black")
  170. def plot_results(ax, ears, results, side, m=100, h_th=0.15, xmin=0, xmax=10000, ymin=-0.1, ymax=0.5, zoom_in_params=None):
  171. """Plot fast pattern matching vs simple thresholding results for each EAR time series."""
  172. # set values
  173. if side == 'right':
  174. c = 'r'
  175. ear = ears[0]
  176. sorted_indices = results["sorted_indices_r"]
  177. peaks = results["peaks_r"]
  178. heights = results["heights_r"]
  179. else:
  180. c = 'b'
  181. ear = ears[1]
  182. sorted_indices = results["sorted_indices_l"]
  183. peaks = results["peaks_l"]
  184. heights = results["heights_l"]
  185. # EAR time series
  186. ax.plot(ear, c=c, zorder=1)
  187. # find peaks
  188. ax.hlines(h_th, xmin, xmax, linestyles='dashed', zorder=0) # showing threshold
  189. ax.scatter(peaks, heights, marker='x', zorder=2)
  190. # fpm detected regions
  191. for i, match_idx in enumerate(sorted_indices):
  192. ax.axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey", zorder=-1)
  193. # show numbering
  194. plot_indices = sorted_indices [(sorted_indices > xmin) & (sorted_indices < xmax)]
  195. for j in range(len(plot_indices)):
  196. if np.diff(plot_indices)[j-1] <= 200:
  197. ax.text(plot_indices[j]+50, 0.38, str(j+1), fontsize=10)
  198. else:
  199. ax.text(plot_indices[j]-100, 0.38, str(j+1), fontsize=10)
  200. if zoom_in_params != None:
  201. subregion, zoom_in_box = zoom_in_params
  202. plot_zoom_in(ax, ear, sorted_indices, peaks, heights, subregion, zoom_in_box)
  203. ax.set(xlim=[xmin, xmax], ylim=[ymin, ymax])
  204. if side == 'right':
  205. ax.set_xticks([])
  206. else:
  207. ax.set_xlabel("Frame", fontsize=10)
  208. # other plotting functions for visualization
  209. def plot_fpm(T: np.ndarray, match_indices: np.ndarray, m: int, th: float, xmin=0, xmax=20000, save_path=None, show=False):
  210. """Plot fast pattern matching, one graph."""
  211. fig, ax = plt.subplots(figsize=(20, 6), sharex=True, gridspec_kw={'hspace': 0})
  212. plt.suptitle('Fast Pattern Matching', fontsize='20')
  213. ax.plot(T)
  214. ax.set_ylabel('EAR', fontsize='14')
  215. for i, match_idx in enumerate(match_indices):
  216. ax.axvspan(match_indices[i], match_indices[i] + m, 0, 1, facecolor="lightgrey")
  217. #ax.text(match_indices[i], 0, str(i+1), color="black", fontsize=20)
  218. ax.set(xlim=[xmin, xmax])
  219. ax.minorticks_on()
  220. ax.set_ylabel('EAR')
  221. ax.set_xlabel('Frame', fontsize ='14')
  222. if save_path != None:
  223. plt.savefig(save_path)
  224. if show == False:
  225. plt.close()
  226. else:
  227. plt.show()
  228. def plot_fpm2(ear_r: np.ndarray, ear_l: np.ndarray, match_indices_r: np.ndarray, match_indices_l: np.ndarray, m: int, th: float, xmin=0, xmax=20000, save_path=None, show=False):
  229. """Plot fast pattern matching, right and left."""
  230. fig, axs = plt.subplots(2, figsize=(20, 6), sharex=True, gridspec_kw={'hspace': 0.1})
  231. plt.suptitle('Fast Pattern Matching', fontsize='20')
  232. axs[0].plot(ear_r, c='r')
  233. axs[1].plot(ear_l, c='b')
  234. for i, match_idx in enumerate(match_indices_r):
  235. axs[0].axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey")
  236. #axs[0].text(match_idx, 0.3, str(i+1), color="black", fontsize=20)
  237. for i, match_idx in enumerate(match_indices_l):
  238. axs[1].axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey")
  239. #axs[1].text(match_indices_l[i], 0, str(i+1), color="black", fontsize=20)
  240. axs[0].set(xlim=[xmin, xmax])
  241. axs[0].minorticks_on()
  242. axs[0].set_ylabel('EAR right')
  243. axs[1].set_ylabel('EAR left')
  244. axs[1].set_xlabel('Frame')
  245. if save_path != None:
  246. plt.savefig(save_path)
  247. if show == False:
  248. plt.close()
  249. else:
  250. plt.show()
  251. # histogram
  252. def plot_prom_hist(proms, eye="right", rel_freq=True, xmin=0, xmax=0.5, ymin=0, ymax=0.3, save_path=None):
  253. """Plot histogram for EAR promineces (both eyes)."""
  254. if eye == "right":
  255. color = "r"
  256. else:
  257. color = "b"
  258. n_bins = int(np.sqrt(len(proms)))
  259. fig, ax = plt.subplots()
  260. ax.set_title(f"Histogram for EAR prominence ({eye} eye)")
  261. ax.set_xlabel("EAR Prominence")
  262. ax.set(xlim=[xmin, xmax], ylim=[ymin, ymax])
  263. if rel_freq == True:
  264. ax.hist(proms, bins=n_bins, edgecolor="white", weights=np.ones_like(proms) / len(proms), color=color)
  265. ax.set_ylabel("Relative Frequency")
  266. else:
  267. ax.hist(proms, bins=n_bins, edgecolor="white", color=color)
  268. ax.set_ylabel("Frequency")
  269. if save_path != None:
  270. # plt.savefig(f"./outputs/histogram/histogram_m{m}_th{th}_bin{n_bins}_r.png")
  271. plt.savefig(save_path)