|
@@ -4,7 +4,7 @@ import stumpy
|
|
|
from scipy import signal
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
-##########Prototype##########
|
|
|
+##########Prototyping##########
|
|
|
# prototype extraction
|
|
|
def motif_extraction(ear_ts: np.ndarray, m=100, max_matches=10):
|
|
|
"""Extract top motifs in EAR time series."""
|
|
@@ -12,9 +12,15 @@ def motif_extraction(ear_ts: np.ndarray, m=100, max_matches=10):
|
|
|
motif_distances, motif_indices = stumpy.motifs(ear_ts, mp[:, 0], max_matches=max_matches)
|
|
|
return motif_distances, motif_indices
|
|
|
|
|
|
-# manual prototype definition
|
|
|
+def learn_prototype(ear_ts: np.ndarray, m=100, max_matches=10):
|
|
|
+ """Return motif no.1 as prototype."""
|
|
|
+ _, motif_indices = motif_extraction(ear_ts, m, max_matches)
|
|
|
+ motif_01 = ear_ts[motif_indices[0][0]+m//2:motif_indices[0][0]+m//2+m]
|
|
|
+ return motif_01
|
|
|
+
|
|
|
+# manual definition
|
|
|
def combined_gaussian(sig1: float, sig2: float, avg: float, prom: float, m=100, mu=40, noise=None):
|
|
|
- """Manual protype composed of two Gaussians."""
|
|
|
+ """Manual prototype composed of two Gaussians."""
|
|
|
y1 = - prom * signal.gaussian(2*m, std=sig1) + avg
|
|
|
y2 = - prom * signal.gaussian(2*m, std=sig2) + avg
|
|
|
y = np.append(y1[:m], y2[m:])
|
|
@@ -24,8 +30,14 @@ def combined_gaussian(sig1: float, sig2: float, avg: float, prom: float, m=100,
|
|
|
|
|
|
return y[m-mu:2*m-mu]
|
|
|
|
|
|
-##########Matching##########
|
|
|
-# pattern matching
|
|
|
+def nosie(noise_std: float, m=100):
|
|
|
+ "Random noise based on learned data."
|
|
|
+ np.random.seed(0)
|
|
|
+ noise = (np.random.random(2*100) * 2 - 1) * noise_std
|
|
|
+ return noise
|
|
|
+
|
|
|
+##########Detection##########
|
|
|
+# blink pattern matching
|
|
|
def fpm(Q: np.ndarray, T: np.ndarray, th=3.0):
|
|
|
"""Fast Pattern Matching"""
|
|
|
def threshold(D):
|
|
@@ -34,20 +46,6 @@ def fpm(Q: np.ndarray, T: np.ndarray, th=3.0):
|
|
|
# match_indices = matches[:, 1]
|
|
|
return matches
|
|
|
|
|
|
-# simple threholding
|
|
|
-def find_peaks_in_ear_ts(ear_ts: np.ndarray, h_th=0.15, p_th=None, t_th=None, d_th=50):
|
|
|
- """
|
|
|
- Find peaks in EAR time series.
|
|
|
- h_th = 0.15 # height threshold
|
|
|
- p_th = None # prominence threshold
|
|
|
- t_th = None # threshold
|
|
|
- d_th = 50 # distance threshold
|
|
|
- """
|
|
|
- peaks, properties = signal.find_peaks(-ear_ts, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
|
|
|
- heights = - properties["peak_heights"]
|
|
|
- return peaks, heights
|
|
|
-
|
|
|
-##########Analysis##########
|
|
|
def index_matching(indices1: np.ndarray, indices2: np.ndarray, max_distance=50):
|
|
|
"Match indices saved in two arrays."
|
|
|
matched_pairs = []
|
|
@@ -61,26 +59,20 @@ def index_matching(indices1: np.ndarray, indices2: np.ndarray, max_distance=50):
|
|
|
no_match.append(idx1)
|
|
|
return np.array(matched_pairs), np.array(no_match)
|
|
|
|
|
|
-def get_apex(T: np.ndarray, m: int, match_indices: np.ndarray):
|
|
|
- """Estimated apex in each extracted matches."""
|
|
|
- apex_indices = []
|
|
|
- apex_proms = []
|
|
|
- for idx in match_indices:
|
|
|
- apex_prom = np.max(T[idx:idx+m]) - np.min(T[idx:idx+m])
|
|
|
- apex_idx = idx + np.argmin(T[idx:idx+m])
|
|
|
- apex_indices.append(apex_idx)
|
|
|
- apex_proms.append(apex_prom)
|
|
|
- return np.array(apex_indices), np.array(apex_proms)
|
|
|
-
|
|
|
-def get_stats(diff: np.ndarray) -> dict:
|
|
|
- """Get statistics (avg, std, median) and save them in a dict."""
|
|
|
- diff_stats = dict()
|
|
|
- diff_stats["avg"] = np.mean(diff)
|
|
|
- diff_stats["std"] = np.std(diff)
|
|
|
- diff_stats["median"] = np.median(diff)
|
|
|
- return diff_stats
|
|
|
+# find peaks
|
|
|
+def find_peaks_in_ear_ts(ear_ts: np.ndarray, h_th=0.15, p_th=None, t_th=None, d_th=50):
|
|
|
+ """
|
|
|
+ Find peaks in EAR time series.
|
|
|
+ h_th = 0.15 # height threshold
|
|
|
+ p_th = None # prominence threshold
|
|
|
+ t_th = None # threshold
|
|
|
+ d_th = 50 # distance threshold
|
|
|
+ """
|
|
|
+ peaks, properties = signal.find_peaks(-ear_ts, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
|
|
|
+ heights = - properties["peak_heights"]
|
|
|
+ return peaks, heights
|
|
|
|
|
|
-def cal_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_th=0.15, p_th=None, t_th=None, d_th=50, save_path=None):
|
|
|
+def cal_bpm_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_th=0.15, p_th=None, t_th=None, d_th=50, save_path=None):
|
|
|
"""Caculate and save find peaks and fast pattern matching results."""
|
|
|
# fast pattern matching
|
|
|
m = len(prototype)
|
|
@@ -95,8 +87,8 @@ def cal_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_t
|
|
|
matched_pairs, no_match = index_matching(sorted_indices_r, sorted_indices_l, max_distance=50)
|
|
|
|
|
|
# find peaks
|
|
|
- peaks_r, heights_r = find_peaks_in_ear_ts(-ear_r, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
|
|
|
- peaks_l, heights_l = find_peaks_in_ear_ts(-ear_l, height=-h_th, threshold=t_th, prominence=p_th, distance=d_th)
|
|
|
+ peaks_r, heights_r = find_peaks_in_ear_ts(ear_r, h_th, t_th, p_th, d_th)
|
|
|
+ peaks_l, heights_l = find_peaks_in_ear_ts(ear_l, h_th, t_th, p_th, d_th)
|
|
|
|
|
|
# save results
|
|
|
results = {}
|
|
@@ -119,7 +111,41 @@ def cal_results(ear_r: np.ndarray, ear_l: np.ndarray, prototype, fpm_th=3.0, h_t
|
|
|
|
|
|
return results
|
|
|
|
|
|
-# visulization
|
|
|
+##########Analysis##########
|
|
|
+def get_apex(T: np.ndarray, m: int, match_indices: np.ndarray):
|
|
|
+ """Estimated apex in each extracted matches."""
|
|
|
+ apex_indices = []
|
|
|
+ apex_proms = []
|
|
|
+ for idx in match_indices:
|
|
|
+ apex_prom = np.max(T[idx:idx+m]) - np.min(T[idx:idx+m])
|
|
|
+ apex_idx = idx + np.argmin(T[idx:idx+m])
|
|
|
+ apex_indices.append(apex_idx)
|
|
|
+ apex_proms.append(apex_prom)
|
|
|
+ return np.array(apex_indices), np.array(apex_proms)
|
|
|
+
|
|
|
+def get_stats(diff: np.ndarray) -> dict:
|
|
|
+ """Get statistics (avg, std, median) and save them in a dict."""
|
|
|
+ diff_stats = dict()
|
|
|
+ diff_stats["avg"] = np.mean(diff)
|
|
|
+ diff_stats["std"] = np.std(diff)
|
|
|
+ diff_stats["median"] = np.median(diff)
|
|
|
+ return diff_stats
|
|
|
+
|
|
|
+# other utilities
|
|
|
+def smooth(data: np.ndarray, window_len=5, window="flat"):
|
|
|
+ "Function for smoothing the data. For now, window type: the moving average (flat)."
|
|
|
+ if data.ndim != 1:
|
|
|
+ raise ValueError("Only accept 1D array as input.")
|
|
|
+
|
|
|
+ if data.size < window_len:
|
|
|
+ raise ValueError("The input data should be larger than the window size.")
|
|
|
+
|
|
|
+ if window == "flat":
|
|
|
+ kernel = np.ones(window_len) / window_len
|
|
|
+ s = np.convolve(data, kernel, mode="same")
|
|
|
+ return s
|
|
|
+
|
|
|
+##########Analysis##########
|
|
|
def plot_ear(ear_r: np.ndarray, ear_l: np.ndarray, xmin=0, xmax=20000):
|
|
|
"""Plot right and left ear score."""
|
|
|
fig, axs = plt.subplots(2, figsize=(20, 6), sharex=True, sharey=True, gridspec_kw={'hspace': 0})
|
|
@@ -151,16 +177,65 @@ def plot_mp(ts: np.ndarray, mp: np.ndarray):
|
|
|
plt.show()
|
|
|
return True
|
|
|
|
|
|
-# other utilities
|
|
|
-def smooth(data: np.ndarray, window_len=5, window="flat"):
|
|
|
- "Function for smoothing the data. For now, window type: the moving average (flat)."
|
|
|
- if data.ndim != 1:
|
|
|
- raise ValueError("Only accept 1D array as input.")
|
|
|
+def plot_zoom_in(ax, ear, sorted_indices, peaks, heights, subregion, zoom_in_box, m=100):
|
|
|
+ """Zoom in subregion on the orginial plot."""
|
|
|
+ x1, x2, y1, y2 = subregion
|
|
|
+ x_in, y_in, w_in, h_in = zoom_in_box
|
|
|
+ axin = ax.inset_axes([x_in, y_in, w_in, h_in],
|
|
|
+ xlim=(x1, x2), ylim=(y1, y2),
|
|
|
+ xticklabels=[], yticklabels=[])
|
|
|
+ for i, match_idx in enumerate(sorted_indices):
|
|
|
+ axin.axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey")
|
|
|
+ axin.plot(ear, c="r", zorder=1)
|
|
|
+ axin.scatter(peaks, heights, marker='x', zorder=2)
|
|
|
+ axin.set_xticks([])
|
|
|
+ axin.set_yticks([])
|
|
|
+
|
|
|
+ ax.indicate_inset_zoom(axin, edgecolor="black")
|
|
|
+
|
|
|
+def plot_results(ax, ears, results, side, m=100, h_th=0.15, xmin=0, xmax=10000, ymin=-0.1, ymax=0.5, zoom_in_params=None):
|
|
|
+ """Plot fast pattern matching vs simple thresholding results for each EAR time series."""
|
|
|
+ # set values
|
|
|
+ if side == 'right':
|
|
|
+ c = 'r'
|
|
|
+ ear = ears[0]
|
|
|
+ sorted_indices = results["sorted_indices_r"]
|
|
|
+ peaks = results["peaks_r"]
|
|
|
+ heights = results["heights_r"]
|
|
|
+ else:
|
|
|
+ c = 'b'
|
|
|
+ ear = ears[1]
|
|
|
+ sorted_indices = results["sorted_indices_l"]
|
|
|
+ peaks = results["peaks_l"]
|
|
|
+ heights = results["heights_l"]
|
|
|
|
|
|
- if data.size < window_len:
|
|
|
- raise ValueError("The input data should be larger than the window size.")
|
|
|
+ # EAR time series
|
|
|
+ ax.plot(ear, c=c, zorder=1)
|
|
|
+
|
|
|
+ # find peaks
|
|
|
+ ax.hlines(h_th, xmin, xmax, linestyles='dashed', zorder=0) # showing threshold
|
|
|
+ ax.scatter(peaks, heights, marker='x', zorder=2)
|
|
|
+
|
|
|
+ # fpm detected regions
|
|
|
+ for i, match_idx in enumerate(sorted_indices):
|
|
|
+ ax.axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey", zorder=-1)
|
|
|
|
|
|
- if window == "flat":
|
|
|
- kernel = np.ones(window_len) / window_len
|
|
|
- s = np.convolve(data, kernel, mode="same")
|
|
|
- return s
|
|
|
+ # show numbering
|
|
|
+ plot_indices = sorted_indices [(sorted_indices > xmin) & (sorted_indices < xmax)]
|
|
|
+ for j in range(len(plot_indices)):
|
|
|
+ if np.diff(plot_indices)[j-1] <= 200:
|
|
|
+ ax.text(plot_indices[j]+50, 0.38, str(j+1), fontsize=10)
|
|
|
+ else:
|
|
|
+ ax.text(plot_indices[j]-100, 0.38, str(j+1), fontsize=10)
|
|
|
+
|
|
|
+ if zoom_in_params != None:
|
|
|
+ subregion, zoom_in_box = zoom_in_params
|
|
|
+ plot_zoom_in(ax, ear, sorted_indices, peaks, heights, subregion, zoom_in_box)
|
|
|
+
|
|
|
+ ax.set(xlim=[xmin, xmax], ylim=[ymin, ymax])
|
|
|
+ if side == 'right':
|
|
|
+ ax.set_xticks([])
|
|
|
+ else:
|
|
|
+ ax.set_xlabel("Frame", fontsize=10)
|
|
|
+
|
|
|
+
|