Procházet zdrojové kódy

add other visualization functions

Yuxuan Xie před 1 rokem
rodič
revize
5aacb20f96
1 změnil soubory, kde provedl 79 přidání a 0 odebrání
  1. 79 0
      eye_state_prototype.py

+ 79 - 0
eye_state_prototype.py

@@ -238,4 +238,83 @@ def plot_results(ax, ears, results, side, m=100, h_th=0.15, xmin=0, xmax=10000,
     else:
         ax.set_xlabel("Frame", fontsize=10)
 
+# other plotting functions for visualization
+def plot_fpm(T: np.ndarray, match_indices: np.ndarray, m: int, th: float, xmin=0, xmax=20000, save_path=None, show=False):
+    """Plot fast pattern matching, one graph."""
+    fig, ax = plt.subplots(figsize=(20, 6), sharex=True, gridspec_kw={'hspace': 0})
+    plt.suptitle('Fast Pattern Matching', fontsize='20')
 
+    ax.plot(T)
+    ax.set_ylabel('EAR', fontsize='14')
+
+    for i, match_idx in enumerate(match_indices):
+        ax.axvspan(match_indices[i], match_indices[i] + m, 0, 1, facecolor="lightgrey")
+        #ax.text(match_indices[i], 0, str(i+1), color="black", fontsize=20)
+
+    ax.set(xlim=[xmin, xmax])
+    ax.minorticks_on()
+    ax.set_ylabel('EAR')
+    ax.set_xlabel('Frame', fontsize ='14')
+
+    if save_path != None:
+        plt.savefig(save_path)
+
+    if show == False:
+        plt.close()
+    else:
+        plt.show()
+
+def plot_fpm2(ear_r: np.ndarray, ear_l: np.ndarray, match_indices_r: np.ndarray, match_indices_l: np.ndarray, m: int, th: float, xmin=0, xmax=20000, save_path=None, show=False):
+    """Plot fast pattern matching, right and left."""
+    fig, axs = plt.subplots(2, figsize=(20, 6), sharex=True, gridspec_kw={'hspace': 0.1})
+    plt.suptitle('Fast Pattern Matching', fontsize='20')
+
+    axs[0].plot(ear_r, c='r')
+    axs[1].plot(ear_l, c='b')
+
+    for i, match_idx in enumerate(match_indices_r):
+        axs[0].axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey")
+        #axs[0].text(match_idx, 0.3, str(i+1), color="black", fontsize=20)
+
+    for i, match_idx in enumerate(match_indices_l):
+        axs[1].axvspan(match_idx, match_idx + m, 0, 1, facecolor="lightgrey")
+        #axs[1].text(match_indices_l[i], 0, str(i+1), color="black", fontsize=20)
+
+    axs[0].set(xlim=[xmin, xmax])
+    axs[0].minorticks_on()
+    axs[0].set_ylabel('EAR right')
+    axs[1].set_ylabel('EAR left')
+    axs[1].set_xlabel('Frame')
+    
+    if save_path != None:
+        plt.savefig(save_path)
+
+    if show == False:
+        plt.close()
+    else:
+        plt.show()
+
+# histogram
+def plot_prom_hist(proms, eye="right", rel_freq=True, xmin=0, xmax=0.5, ymin=0, ymax=0.3, save_path=None):
+    """Plot histogram for EAR promineces (both eyes)."""
+    if eye == "right":
+        color = "r"
+    else:
+        color = "b"
+        
+    n_bins = int(np.sqrt(len(proms)))
+
+    fig, ax = plt.subplots()
+    ax.set_title(f"Histogram for EAR prominence ({eye} eye)")
+    ax.set_xlabel("EAR Prominence")
+    ax.set(xlim=[xmin, xmax], ylim=[ymin, ymax])
+    if rel_freq == True:
+        ax.hist(proms, bins=n_bins, edgecolor="white", weights=np.ones_like(proms) / len(proms), color=color)
+        ax.set_ylabel("Relative Frequency")
+    else:
+        ax.hist(proms, bins=n_bins, edgecolor="white", color=color)
+        ax.set_ylabel("Frequency")
+    
+    if save_path != None:
+        # plt.savefig(f"./outputs/histogram/histogram_m{m}_th{th}_bin{n_bins}_r.png")
+        plt.savefig(save_path)