# eye_state_prototype.py

  1. import numpy as np
  2. import pandas as pd
  3. import stumpy
  4. from scipy import signal
  5. import matplotlib.pyplot as plt
  6. ##########Prototype##########
  7. # prototype extraction
  8. def motif_extraction(ear_ts: np.ndarray, m=100, max_matches=10):
  9. """Extract top motifs in EAR time series."""
  10. mp = stumpy.stump(ear_ts, m) # matrix profile
  11. motif_distances, motif_indices = stumpy.motifs(ear_ts, mp[:, 0], max_matches=max_matches)
  12. return motif_distances, motif_indices
  13. # manual prototype definition
  14. def combined_gaussian(sig1: float, sig2: float, avg: float, prom: float, m=100, mu=40, noise=None):
  15. """Manual protype composed of two Gaussians."""
  16. y1 = - prom * signal.gaussian(2*m, std=sig1) + avg
  17. y2 = - prom * signal.gaussian(2*m, std=sig2) + avg
  18. y = np.append(y1[:m], y2[m:])
  19. if noise is not None:
  20. y = y + noise
  21. return y[m-mu:2*m-mu]
  22. ##########Matching##########
  23. # pattern matching
  24. def fpm(Q: np.ndarray, T: np.ndarray, th=3.0):
  25. """Fast Pattern Matching"""
  26. def threshold(D):
  27. return np.nanmax([np.nanmean(D) - th * np.std(D), np.nanmin(D)])
  28. matches = stumpy.match(Q, T, max_distance=threshold)
  29. # match_indices = matches[:, 1]
  30. return matches
  31. # simple threholding
  32. def find_peaks_in_ear_ts(ear_ts: np.ndarray, h_th=0.15, p_th=None, t_th=None, d_th=50):
  33. """
  34. Find peaks in EAR time series.
  35. h_th = 0.15 # height threshold
  36. p_th = None # prominence threshold
  37. t_th = None # threshold
  38. d_th = 50 # distance threshold
  39. """
  40. peaks, properties = signal.find_peaks(-ear_ts, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
  41. heights = - properties["peak_heights"]
  42. return peaks, heights
  43. ##########Analysis##########
  44. def index_matching(indices1: np.ndarray, indices2: np.ndarray, max_distance=50):
  45. "Match indices saved in two arrays."
  46. matched_pairs = []
  47. no_match = []
  48. for idx1 in indices1:
  49. dists = np.abs(indices2 - idx1)
  50. min_dist = np.min(dists)
  51. if min_dist < max_distance:
  52. matched_pairs.append([idx1, indices2[np.argmin(dists)]]) # when there are two equal-dist matches, always keep the first onr
  53. else:
  54. no_match.append(idx1)
  55. return np.array(matched_pairs), np.array(no_match)
  56. def get_apex(T: np.ndarray, m: int, match_indices: np.ndarray):
  57. """Estimated apex in each extracted matches."""
  58. apex_indices = []
  59. apex_proms = []
  60. for idx in match_indices:
  61. apex_prom = np.max(T[idx:idx+m]) - np.min(T[idx:idx+m])
  62. apex_idx = idx + np.argmin(T[idx:idx+m])
  63. apex_indices.append(apex_idx)
  64. apex_proms.append(apex_prom)
  65. return np.array(apex_indices), np.array(apex_proms)
  66. def get_stats(diff: np.ndarray) -> dict:
  67. """Get statistics (avg, std, median) and save them in a dict."""
  68. diff_stats = dict()
  69. diff_stats["avg"] = np.mean(diff)
  70. diff_stats["std"] = np.std(diff)
  71. diff_stats["median"] = np.median(diff)
  72. return diff_stats
  73. def cal_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_th=0.15, p_th=None, t_th=None, d_th=50, save_path=None):
  74. """Caculate and save find peaks and fast pattern matching results."""
  75. # fast pattern matching
  76. m = len(prototype)
  77. matches_r = fpm(Q=prototype, T=ear_r, th=fpm_th)
  78. matches_l = fpm(Q=prototype, T=ear_l, th=fpm_th)
  79. match_indices_r = matches_r[:, 1]
  80. match_indices_l = matches_l[:, 1]
  81. # index matching
  82. sorted_indices_r = np.sort(match_indices_r)
  83. sorted_indices_l = np.sort(match_indices_l)
  84. matched_pairs, no_match = index_matching(sorted_indices_r, sorted_indices_l, max_distance=50)
  85. # find peaks
  86. peaks_r, heights_r = find_peaks_in_ear_ts(-ear_r, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
  87. peaks_l, heights_l = find_peaks_in_ear_ts(-ear_l, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
  88. # save results
  89. results = {}
  90. results["match_indices_r"] = matches_r[:, 1]
  91. results["match_values_r"] = matches_r[:, 0]
  92. results["match_indices_l"] = matches_l[:, 1]
  93. results["match_values_l"] = matches_l[:, 0]
  94. results["sorted_indices_r"] = sorted_indices_r
  95. results["sorted_indices_l"] = sorted_indices_l
  96. results["matched_pairs_r"] = matched_pairs[:, 0]
  97. results["matched_pairs_l"] = matched_pairs[:, 1]
  98. results["peaks_r"] = peaks_r
  99. results["heights_r"] = heights_r
  100. results["peaks_l"] = peaks_l
  101. results["heights_l"] = heights_l
  102. if save_path is not None:
  103. results_df = pd.DataFrame({key:pd.Series(value) for key, value in results.items()})
  104. results_df.to_csv(save_path)
  105. return results
  106. # visulization
  107. def plot_ear(ear_r: np.ndarray, ear_l: np.ndarray, xmin=0, xmax=20000):
  108. """Plot right and left ear score."""
  109. fig, axs = plt.subplots(2, figsize=(20, 6), sharex=True, sharey=True, gridspec_kw={'hspace': 0})
  110. axs[0].plot(ear_r, c='r', label='right eye')
  111. axs[0].minorticks_on()
  112. if len(ear_r) < xmax:
  113. xmax = len(ear_r)
  114. axs[0].set_xlim([xmin, xmax])
  115. axs[0].set_title("EAR Time Series", fontsize="30")
  116. axs[0].set_ylabel('right', fontsize="18")
  117. axs[1].plot(ear_l, c='b', label='left eye')
  118. axs[1].set_ylabel('left', fontsize="18")
  119. axs[1].set_xlabel('Frame', fontsize="18")
  120. plt.show()
  121. return True
  122. def plot_mp(ts: np.ndarray, mp: np.ndarray):
  123. """Plot EAR and Matrix Profile."""
  124. fig, axs = plt.subplots(2, figsize=(10, 6), sharex=True, gridspec_kw={'hspace': 0})
  125. plt.suptitle('EAR Score and Matrix Profile', fontsize='30')
  126. axs[0].plot(ts)
  127. axs[0].set_ylabel('EAR', fontsize='20')
  128. axs[0].set(xlim=[0, len(ts)], ylim=[0, 0.4])
  129. axs[0].minorticks_on()
  130. axs[1].set_xlabel('Frame', fontsize ='20')
  131. axs[1].set_ylabel('Matrix Profile', fontsize='20')
  132. axs[1].plot(mp[:, 0])
  133. plt.show()
  134. return True
  135. # other utilities
  136. def smooth(data: np.ndarray, window_len=5, window="flat"):
  137. "Function for smoothing the data. For now, window type: the moving average (flat)."
  138. if data.ndim != 1:
  139. raise ValueError("Only accept 1D array as input.")
  140. if data.size < window_len:
  141. raise ValueError("The input data should be larger than the window size.")
  142. if window == "flat":
  143. kernel = np.ones(window_len) / window_len
  144. s = np.convolve(data, kernel, mode="same")
  145. return s