eye_state_prototype.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import numpy as np
  2. import pandas as pd
  3. import stumpy
  4. from scipy import signal
  5. def fpm(Q: np.ndarray, T: np.ndarray, th=3.0):
  6. """Fast Pattern Matching"""
  7. def threshold(D):
  8. return np.nanmax([np.nanmean(D) - th * np.std(D), np.nanmin(D)])
  9. matches = stumpy.match(Q, T, max_distance=threshold)
  10. # match_indices = matches[:, 1]
  11. return matches
  12. def index_matching(indices1: np.ndarray, indices2: np.ndarray, max_distance=50):
  13. "Match indices saved in two arrays."
  14. matched_pairs = []
  15. no_match = []
  16. for idx1 in indices1:
  17. dists = np.abs(indices2 - idx1)
  18. min_dist = np.min(dists)
  19. if min_dist < max_distance:
  20. matched_pairs.append([idx1, indices2[np.argmin(dists)]]) # when there are two equal-dist matches, always keep the first onr
  21. else:
  22. no_match.append(idx1)
  23. return np.array(matched_pairs), np.array(no_match)
  24. def get_apex(T: np.ndarray, m: int, match_indices: np.ndarray):
  25. """Estimated apex in each extracted matches."""
  26. apex_indices = []
  27. apex_proms = []
  28. for idx in match_indices:
  29. apex_prom = np.max(T[idx:idx+m]) - np.min(T[idx:idx+m])
  30. apex_idx = idx + np.argmin(T[idx:idx+m])
  31. apex_indices.append(apex_idx)
  32. apex_proms.append(apex_prom)
  33. return np.array(apex_indices), np.array(apex_proms)
  34. def get_stats(diff: np.ndarray) -> dict:
  35. """Get statistics (avg, std, median) and save them in a dict."""
  36. diff_stats = dict()
  37. diff_stats["avg"] = np.mean(diff)
  38. diff_stats["std"] = np.std(diff)
  39. diff_stats["median"] = np.median(diff)
  40. return diff_stats
  41. def combined_gaussian(sig1, sig2, avg, prom):
  42. """Manual Protype composed of two Gaussians."""
  43. y1 = - prom * signal.gaussian(200, std=sig1) + avg
  44. y2 = - prom * signal.gaussian(200, std=sig2) + avg
  45. y = np.append(y1[:100], y2[100:])
  46. return y[60:160]
  47. def smooth(data: np.ndarray, window_len=5, window="flat"):
  48. "Function for smoothing the data. For now, window type: the moving average (flat)."
  49. if data.ndim != 1:
  50. raise ValueError("Only accept 1D array as input.")
  51. if data.size < window_len:
  52. raise ValueError("The input data should be larger than the window size.")
  53. if window == "flat":
  54. kernel = np.ones(window_len) / window_len
  55. s = np.convolve(data, kernel, mode="same")
  56. return s
  57. def cal_results(ear_r, ear_l, prototype, save_path=None):
  58. """Caculate and save find peaks and fast pattern matching results."""
  59. # find peaks
  60. h_th = 0.15 # height threshold
  61. p_th = None # prominence threshold
  62. t_th = None # threshold
  63. d_th = 50 # distance threshold
  64. peaks_r, properties_r = signal.find_peaks(-ear_r, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
  65. heights_r = - properties_r["peak_heights"]
  66. peaks_l, properties_l = signal.find_peaks(-ear_l, height=-h_th, threshold=-0.1, prominence=p_th, distance=d_th)
  67. heights_l = - properties_l["peak_heights"]
  68. # fast pattern matching
  69. m = len(prototype) # m = 100
  70. fpm_th = 3.0
  71. matches_r = fpm(Q=prototype, T=ear_r, th=fpm_th)
  72. matches_l = fpm(Q=prototype, T=ear_l, th=fpm_th)
  73. match_indices_r = matches_r[:, 1]
  74. match_indices_l = matches_l[:, 1]
  75. # index matching
  76. sorted_indices_r = np.sort(match_indices_r)
  77. sorted_indices_l = np.sort(match_indices_l)
  78. matched_pairs, no_match = index_matching(sorted_indices_r, sorted_indices_l, max_distance=50)
  79. # save results
  80. results = {}
  81. results["match_indices_r"] = matches_r[:, 1]
  82. results["match_values_r"] = matches_r[:, 0]
  83. results["match_indices_l"] = matches_l[:, 1]
  84. results["match_values_l"] = matches_l[:, 0]
  85. results["sorted_indices_r"] = sorted_indices_r
  86. results["sorted_indices_l"] = sorted_indices_l
  87. results["matched_pairs_r"] = matched_pairs[:, 0]
  88. results["matched_pairs_l"] = matched_pairs[:, 1]
  89. results["peaks_r"] = peaks_r
  90. results["heights_r"] = heights_r
  91. results["peaks_l"] = peaks_l
  92. results["heights_l"] = heights_l
  93. if save_path is not None:
  94. results_df = pd.DataFrame({key:pd.Series(value) for key, value in results.items()})
  95. results_df.to_csv(save_path)
  96. return results