Pārlūkot izejas kodu

sample experiment content added, basic functions added

- prototype extraction
- combined gaussian
- bpm with learned/manual prototype
Yuxuan Xie 1 gadu atpakaļ
vecāks
revīzija
5c9de1188f
2 mainītis faili ar 160 papildinājumiem un 56 dzēšanām
  1. 127 52
      eye_state_prototype.py
  2. 33 4
      sample_experiment.ipynb

+ 127 - 52
eye_state_prototype.py

@@ -4,7 +4,7 @@ import stumpy
 from scipy import signal
 import matplotlib.pyplot as plt
 
-##########Prototype##########
+##########Prototyping##########
 # prototype extraction
 def motif_extraction(ear_ts: np.ndarray, m=100, max_matches=10):
     """Extract top motifs in EAR time series."""
@@ -12,9 +12,15 @@ def motif_extraction(ear_ts: np.ndarray, m=100, max_matches=10):
     motif_distances, motif_indices = stumpy.motifs(ear_ts, mp[:, 0], max_matches=max_matches)
     return motif_distances, motif_indices
 
-# manual prototype definition
+def learn_prototype(ear_ts: np.ndarray, m=100, max_matches=10):
+    """Return the top-ranked motif of an EAR time series as the prototype.
+
+    The motif indices come from motif_extraction(). The returned slice of
+    ear_ts starts m // 2 samples after the first index reported for the
+    first motif and spans m samples (fewer if it runs past the end of
+    ear_ts).
+    """
+    indices = motif_extraction(ear_ts, m, max_matches)[1]
+    # Shift the window by half a motif length relative to the reported index.
+    # NOTE(review): presumably this centres the blink in the prototype - confirm.
+    start = indices[0][0] + m // 2
+    return ear_ts[start:start + m]
+
+# manual definition
 def combined_gaussian(sig1: float, sig2: float, avg: float, prom: float, m=100, mu=40, noise=None):
-    """Manual protype composed of two Gaussians."""
+    """Manual prototype composed of two Gaussians."""
     y1 = - prom * signal.gaussian(2*m, std=sig1) + avg
     y2 = - prom * signal.gaussian(2*m, std=sig2) + avg
     y = np.append(y1[:m], y2[m:])
@@ -24,8 +30,14 @@ def combined_gaussian(sig1: float, sig2: float, avg: float, prom: float, m=100,
     
     return y[m-mu:2*m-mu]
 
-##########Matching##########
-# pattern matching
+def nosie(noise_std: float, m=100):
+    """Return reproducible uniform random noise sized for a prototype of length m.
+
+    Args:
+        noise_std: Scale of the noise; samples are drawn uniformly from
+            [-noise_std, noise_std).
+        m: Prototype length. 2 * m samples are generated.
+
+    Returns:
+        1-D array holding 2 * m noise samples.
+
+    Note:
+        The global NumPy random generator is re-seeded with 0 on every call,
+        so repeated calls with the same arguments return identical values.
+    """
+    # Fixed seed keeps the generated noise identical between runs.
+    np.random.seed(0)
+    # The sample count follows m instead of a hard-coded 2 * 100.
+    return (np.random.random(2 * m) * 2 - 1) * noise_std
+
+##########Detection##########
+# blink pattern matching
 def fpm(Q: np.ndarray, T: np.ndarray, th=3.0):
     """Fast Pattern Matching"""
     def threshold(D):
@@ -34,20 +46,6 @@ def fpm(Q: np.ndarray, T: np.ndarray, th=3.0):
     # match_indices = matches[:, 1]
     return matches
 
-# simple threholding 
-def find_peaks_in_ear_ts(ear_ts: np.ndarray, h_th=0.15, p_th=None, t_th=None, d_th=50):
-    """
-    Find peaks in EAR time series.
-    h_th = 0.15 # height threshold
-    p_th = None # prominence threshold
-    t_th = None # threshold
-    d_th = 50 # distance threshold
-    """
-    peaks, properties = signal.find_peaks(-ear_ts, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
-    heights = - properties["peak_heights"]
-    return peaks, heights
-
-##########Analysis##########
 def index_matching(indices1: np.ndarray, indices2: np.ndarray, max_distance=50):
     "Match indices saved in two arrays."
     matched_pairs = []
@@ -61,26 +59,20 @@ def index_matching(indices1: np.ndarray, indices2: np.ndarray, max_distance=50):
             no_match.append(idx1)
     return np.array(matched_pairs), np.array(no_match)
 
-def get_apex(T: np.ndarray, m: int, match_indices: np.ndarray):
-    """Estimated apex in each extracted matches."""
-    apex_indices = []
-    apex_proms = []
-    for idx in match_indices:
-        apex_prom = np.max(T[idx:idx+m]) - np.min(T[idx:idx+m])
-        apex_idx = idx + np.argmin(T[idx:idx+m])
-        apex_indices.append(apex_idx)
-        apex_proms.append(apex_prom)
-    return np.array(apex_indices), np.array(apex_proms)
-
-def get_stats(diff: np.ndarray) -> dict:
-    """Get statistics (avg, std, median) and save them in a dict."""
-    diff_stats = dict()
-    diff_stats["avg"] = np.mean(diff)
-    diff_stats["std"] = np.std(diff)
-    diff_stats["median"] = np.median(diff)
-    return diff_stats
+# find peaks
+def find_peaks_in_ear_ts(ear_ts: np.ndarray, h_th=0.15, p_th=None, t_th=None, d_th=50):
+    """
+    Locate local minima of an EAR time series.
+
+    The series is negated so that scipy.signal.find_peaks, which searches
+    for maxima, reports the minima of ear_ts.
+    h_th: height threshold; only minima with an EAR value <= h_th are kept
+    p_th: prominence threshold forwarded to find_peaks
+    t_th: threshold forwarded to find_peaks
+    d_th: distance threshold forwarded to find_peaks
+    Returns the indices of the minima and the EAR values at those indices.
+    """
+    flipped = -ear_ts
+    minima, info = signal.find_peaks(
+        flipped, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th
+    )
+    # peak_heights refer to the negated series; flip the sign back.
+    return minima, -info["peak_heights"]
 
-def cal_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_th=0.15, p_th=None, t_th=None, d_th=50, save_path=None):
+def cal_bpm_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_th=0.15, p_th=None, t_th=None, d_th=50, save_path=None):
     """Caculate and save find peaks and fast pattern matching results."""
     # fast pattern matching
     m = len(prototype)
@@ -95,8 +87,8 @@ def cal_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_t
     matched_pairs, no_match = index_matching(sorted_indices_r, sorted_indices_l, max_distance=50)
 
     # find peaks
-    peaks_r, heights_r = find_peaks_in_ear_ts(-ear_r, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
-    peaks_l, heights_l = find_peaks_in_ear_ts(-ear_l, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
+    # Pass the thresholds by keyword: positionally, t_th and p_th were swapped
+    # relative to find_peaks_in_ear_ts(ear_ts, h_th, p_th, t_th, d_th).
+    peaks_r, heights_r = find_peaks_in_ear_ts(ear_r, h_th=h_th, p_th=p_th, t_th=t_th, d_th=d_th)
+    peaks_l, heights_l = find_peaks_in_ear_ts(ear_l, h_th=h_th, p_th=p_th, t_th=t_th, d_th=d_th)
 
     # save results
     results = {}
@@ -119,7 +111,41 @@ def cal_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_t
     
     return results
 
-# visulization
+##########Analysis##########
+def get_apex(T: np.ndarray, m: int, match_indices: np.ndarray):
+    """Locate the apex (minimum) inside every matched window of T.
+
+    For each start index in match_indices the window T[idx:idx+m] is
+    inspected. Returns two arrays: the absolute index of each window's
+    minimum and each window's max-minus-min range.
+    """
+    positions, ranges = [], []
+    for start in match_indices:
+        window = T[start:start + m]
+        positions.append(start + np.argmin(window))
+        ranges.append(np.max(window) - np.min(window))
+    return np.array(positions), np.array(ranges)
+
+def get_stats(diff: np.ndarray) -> dict:
+    """Summarise diff by its mean, standard deviation and median.
+
+    Returns a dict with the keys "avg", "std" and "median".
+    """
+    return {
+        "avg": np.mean(diff),
+        "std": np.std(diff),
+        "median": np.median(diff),
+    }
+
+# other utilities
+def smooth(data: np.ndarray, window_len=5, window="flat"):
+    """Smooth a 1-D array with a moving-average ("flat") window.
+
+    Args:
+        data: 1-D array to smooth.
+        window_len: Number of samples in the averaging window.
+        window: Window type; only "flat" (moving average) is implemented.
+
+    Returns:
+        Array of the same length as data. np.convolve(mode="same") pads
+        with zeros, so values near both ends are attenuated.
+
+    Raises:
+        ValueError: If data is not 1-D, is shorter than window_len, or an
+            unsupported window type is requested.
+    """
+    if data.ndim != 1:
+        raise ValueError("Only accept 1D array as input.")
+
+    if data.size < window_len:
+        raise ValueError("The input data should be larger than the window size.")
+
+    # Reject unknown window types explicitly; previously this path reached
+    # the return statement with the result variable unbound.
+    if window != "flat":
+        raise ValueError(f"Unsupported window type: {window!r}. Only 'flat' is implemented.")
+
+    kernel = np.ones(window_len) / window_len
+    return np.convolve(data, kernel, mode="same")
+
+##########Visualization##########
 def plot_ear(ear_r: np.ndarray, ear_l: np.ndarray, xmin=0, xmax=20000):
     """Plot right and left ear score."""
     fig, axs = plt.subplots(2, figsize=(20, 6), sharex=True, sharey=True, gridspec_kw={'hspace': 0})
@@ -151,16 +177,65 @@ def plot_mp(ts: np.ndarray, mp: np.ndarray):
     plt.show()
     return True
 
-# other utilities
-def smooth(data: np.ndarray, window_len=5, window="flat"):
-    "Function for smoothing the data. For now, window type: the moving average (flat)."
-    if data.ndim != 1:
-        raise ValueError("Only accept 1D array as input.")
+def plot_zoom_in(ax, ear, sorted_indices, peaks, heights, subregion, zoom_in_box, m=100):
+    """Draw a zoomed-in inset of a subregion on the original plot.
+
+    Args:
+        ax: Axes that receives the inset.
+        ear: EAR time series drawn inside the inset.
+        sorted_indices: Start indices of the matched regions; each one is
+            shaded over m samples.
+        peaks: x positions of the peak markers.
+        heights: y positions of the peak markers.
+        subregion: (x1, x2, y1, y2) data limits shown by the inset.
+        zoom_in_box: (x, y, width, height) bounds passed to ax.inset_axes.
+        m: Width of each shaded matched region.
+    """
+    x1, x2, y1, y2 = subregion
+    x_in, y_in, w_in, h_in = zoom_in_box
+    axin = ax.inset_axes([x_in, y_in, w_in, h_in],
+                        xlim=(x1, x2), ylim=(y1, y2),
+                        xticklabels=[], yticklabels=[])
+    # Shade the matched regions first so the curve is drawn on top of them.
+    for match_idx in sorted_indices:
+        axin.axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey")
+
+    # Draw the curve and markers once, outside the loop, so they also appear
+    # when there are no matches and are not redrawn for every match.
+    axin.plot(ear, c="r", zorder=1)
+    axin.scatter(peaks, heights, marker='x', zorder=2)
+    axin.set_xticks([])
+    axin.set_yticks([])
+
+    ax.indicate_inset_zoom(axin, edgecolor="black")
+
+def plot_results(ax, ears, results, side, m=100, h_th=0.15, xmin=0, xmax=10000, ymin=-0.1, ymax=0.5, zoom_in_params=None):
+    """Plot fast pattern matching vs simple thresholding results for one EAR time series.
+
+    Args:
+        ax: Axes to draw on.
+        ears: Sequence holding the right EAR series at index 0 and the left at index 1.
+        results: Dict read via the keys "sorted_indices_*", "peaks_*" and
+            "heights_*" with suffix "r" or "l".
+        side: 'right' selects the right eye; any other value selects the left.
+        m: Width of each shaded matched region.
+        h_th: Height threshold drawn as a dashed horizontal line.
+        xmin, xmax, ymin, ymax: Axis limits; matches are numbered only inside (xmin, xmax).
+        zoom_in_params: Optional (subregion, zoom_in_box) pair forwarded to plot_zoom_in.
+    """
+    # set values
+    if side == 'right':
+        c = 'r'
+        ear = ears[0]
+        sorted_indices = results["sorted_indices_r"]
+        peaks = results["peaks_r"]
+        heights = results["heights_r"]
+    else:
+        c = 'b'
+        ear = ears[1]
+        sorted_indices = results["sorted_indices_l"]
+        peaks = results["peaks_l"]
+        heights = results["heights_l"]
     
-    if data.size < window_len:
-        raise ValueError("The input data should be larger than the window size.")
+    # EAR time series
+    ax.plot(ear, c=c, zorder=1)
+
+    # find peaks
+    ax.hlines(h_th, xmin, xmax, linestyles='dashed', zorder=0) # showing threshold
+    ax.scatter(peaks, heights, marker='x', zorder=2)
+
+    # fpm detected regions
+    for match_idx in sorted_indices:
+        ax.axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey", zorder=-1)
     
-    if window == "flat":
-        kernel = np.ones(window_len) / window_len
-        s = np.convolve(data, kernel, mode="same")
-    return s
+    # show numbering
+    visible = np.asarray(sorted_indices)
+    visible = visible[(visible > xmin) & (visible < xmax)]
+    # Gap to the previous visible match, computed once. The first match has
+    # no predecessor, so it must not read gaps[-1] (the last gap), which
+    # also raised IndexError when only one match was visible.
+    gaps = np.diff(visible)
+    for j, idx in enumerate(visible):
+        if j > 0 and gaps[j - 1] <= 200:
+            # Close to the previous match: put the label right of the start.
+            ax.text(idx + 50, 0.38, str(j + 1), fontsize=10)
+        else:
+            ax.text(idx - 100, 0.38, str(j + 1), fontsize=10)
+
+    if zoom_in_params is not None:
+        subregion, zoom_in_box = zoom_in_params
+        # Forward m so the inset shades regions of the same width as the main plot.
+        plot_zoom_in(ax, ear, sorted_indices, peaks, heights, subregion, zoom_in_box, m=m)
+
+    ax.set(xlim=[xmin, xmax], ylim=[ymin, ymax])
+    if side == 'right':
+        ax.set_xticks([])
+    else:
+        ax.set_xlabel("Frame", fontsize=10)
+
+

Failā izmaiņas netiks attēlotas, jo tās ir par lielu
+ 33 - 4
sample_experiment.ipynb


Daži faili netika attēloti, jo izmaiņu fails ir pārāk liels