Przeglądaj źródła

Rename function parameter

Tim Büchner 1 rok temu
rodzic
commit
fa68c6fabc
2 zmienionych plików z 19 dodań i 25 usunięć
  1. 13 22
      sample_experiments/unsupervised_extractin.ipynb
  2. 6 3
      src/ebpm/match.py

+ 13 - 22
sample_experiments/unsupervised_extractin.ipynb

@@ -154,25 +154,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 19,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/home/buechner/anaconda3/envs/ebpm/lib/python3.10/site-packages/numba/np/ufunc/parallel.py:371: NumbaWarning: The TBB threading layer requires TBB version 2021 update 6 or later i.e., TBB_INTERFACE_VERSION >= 12060. Found TBB_INTERFACE_VERSION = 12050. The TBB threading layer is disabled.\n",
-      "  warnings.warn(problem)\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "candidates = ebpm.unsupervised.extract_candidates(np.concatenate([ear_l, ear_r]))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -181,7 +172,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [
     {
@@ -202,7 +193,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
@@ -230,17 +221,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 38,
    "metadata": {},
    "outputs": [],
    "source": [
-    "matches_l = ebpm.match.find_prototype(ear_l, prototype)\n",
-    "matches_r = ebpm.match.find_prototype(ear_r, prototype)"
+    "matches_l = ebpm.match.find_prototype(ear_l, prototype, max_prototype_distance=3.0)\n",
+    "matches_r = ebpm.match.find_prototype(ear_r, prototype, max_prototype_distance=3.0)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 39,
    "metadata": {},
    "outputs": [
     {
@@ -258,7 +249,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 40,
    "metadata": {},
    "outputs": [
     {
@@ -324,7 +315,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 41,
    "metadata": {},
    "outputs": [
     {
@@ -368,7 +359,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 42,
    "metadata": {},
    "outputs": [
     {
@@ -412,7 +403,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 43,
    "metadata": {},
    "outputs": [
     {

+ 6 - 3
src/ebpm/match.py

@@ -3,21 +3,24 @@ __all__ = ["find_prototype", "match_found_intervals"]
 import numpy as np
 import stumpy
 
-def find_prototype(ear_ts: np.ndarray, prototype: np.ndarray, th=3.0):
+def find_prototype(
+    ear_ts: np.ndarray, 
+    prototype: np.ndarray, 
+    max_prototype_distance: float=3.0):
     """
     Find occurrences of a prototype pattern within a time series.
 
     Parameters:
     ear_ts (np.ndarray): The time series to search for occurrences of the prototype pattern.
     prototype (np.ndarray): The prototype pattern to search for within the time series.
-    th (float, optional): The threshold value used to determine matches. Defaults to 3.0.
+    max_prototype_distance (float, optional): The threshold value used to determine matches. Defaults to 3.0.
 
     Returns:
     list: A list of intervals where the prototype pattern is found in the time series.
           Each interval is represented as [from, to, distance_to_prototype].
     """
     def threshold(D):
-        return np.nanmax([np.nanmean(D) - th * np.std(D), np.nanmin(D)])
+        return np.nanmax([np.nanmean(D) - max_prototype_distance * np.std(D), np.nanmin(D)])
     
     matches = stumpy.match(prototype, ear_ts, max_distance=threshold)
     # sort the matches by index to get the original order