Переглянути джерело

update prototype extraction

learn prototype to learn prototypes, return all motifs
Yuxuan Xie 1 рік тому
батько
коміт
1bc86c48cc
2 змінених файлів з 16 додано та 16 видалено
  1. 4 4
      eye_state_prototype.py
  2. 12 12
      sample_experiment.ipynb

+ 4 - 4
eye_state_prototype.py

@@ -12,11 +12,11 @@ def motif_extraction(ear_ts: np.ndarray, m=100, max_matches=10):
     motif_distances, motif_indices = stumpy.motifs(ear_ts, mp[:, 0], max_matches=max_matches)
     motif_distances, motif_indices = stumpy.motifs(ear_ts, mp[:, 0], max_matches=max_matches)
     return motif_distances, motif_indices
     return motif_distances, motif_indices
 
 
-def learn_prototype(ear_ts: np.ndarray, m=100, max_matches=10):
-    """Return motif no.1 as prototype."""
+def learn_prototypes(ear_ts: np.ndarray, m=100, max_matches=10):
+    """Return top motifs."""
     _, motif_indices = motif_extraction(ear_ts, m, max_matches)
     _, motif_indices = motif_extraction(ear_ts, m, max_matches)
-    motif_01 = ear_ts[motif_indices[0][0]+m//2:motif_indices[0][0]+m//2+m]
-    return motif_01
+    motifs = np.array([ear_ts[idx:idx+m] for idx in (motif_indices[0] + m // 2)])
+    return motifs
 
 
 # manual definition
 # manual definition
 def combined_gaussian(sig1: float, sig2: float, avg: float, prom: float, m=100, mu=40, noise=None):
 def combined_gaussian(sig1: float, sig2: float, avg: float, prom: float, m=100, mu=40, noise=None):

+ 12 - 12
sample_experiment.ipynb

@@ -9,7 +9,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 103,
+   "execution_count": 17,
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
     {
     {
@@ -35,7 +35,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 104,
+   "execution_count": 20,
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
@@ -48,14 +48,14 @@
     "import stumpy\n",
     "import stumpy\n",
     "\n",
     "\n",
     "import eye_state_prototype\n",
     "import eye_state_prototype\n",
-    "from eye_state_prototype import motif_extraction, learn_prototype, combined_gaussian\n",
+    "from eye_state_prototype import motif_extraction, learn_prototype, learn_prototypes, combined_gaussian\n",
     "from eye_state_prototype import fpm, find_peaks_in_ear_ts, cal_bpm_results\n",
     "from eye_state_prototype import fpm, find_peaks_in_ear_ts, cal_bpm_results\n",
     "from eye_state_prototype import plot_ear, plot_results"
     "from eye_state_prototype import plot_ear, plot_results"
    ]
    ]
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 105,
+   "execution_count": 3,
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
@@ -71,7 +71,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 106,
+   "execution_count": 4,
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
     {
     {
@@ -90,7 +90,7 @@
        "True"
        "True"
       ]
       ]
      },
      },
-     "execution_count": 106,
+     "execution_count": 4,
      "metadata": {},
      "metadata": {},
      "output_type": "execute_result"
      "output_type": "execute_result"
     }
     }
@@ -123,7 +123,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 107,
+   "execution_count": 5,
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
@@ -133,16 +133,16 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 108,
+   "execution_count": 22,
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
     {
     {
      "data": {
      "data": {
       "text/plain": [
       "text/plain": [
-       "[<matplotlib.lines.Line2D at 0x7fc67da6eec0>]"
+       "[<matplotlib.lines.Line2D at 0x7fa3d1782530>]"
       ]
       ]
      },
      },
-     "execution_count": 108,
+     "execution_count": 22,
      "metadata": {},
      "metadata": {},
      "output_type": "execute_result"
      "output_type": "execute_result"
     },
     },
@@ -159,8 +159,8 @@
    ],
    ],
    "source": [
    "source": [
     "# get motif No.1\n",
     "# get motif No.1\n",
-    "motif_01 = learn_prototype(ear_r[0:20000], m=100, max_matches=10)\n",
-    "plt.plot(motif_01)"
+    "motifs = learn_prototypes(ear_r[0:20000], m=100, max_matches=10)\n",
+    "plt.plot(motifs[0])"
    ]
    ]
   },
   },
   {
   {