Browse Source

Update unsupervised experiments

- rewrite code as package
- update functions names and docstrings
- update examples
Tim Büchner 1 year ago
parent
commit
79611ceebd
7 changed files with 399 additions and 0 deletions
  1. 11 0
      README.md
  2. 133 0
      sample_experiments/unsupervised_extractin.ipynb
  3. 1 0
      setup.cfg
  4. 9 0
      src/ebpm/__init__.py
  5. 84 0
      src/ebpm/match.py
  6. 127 0
      src/ebpm/plot.py
  7. 34 0
      src/ebpm/unsupervised.py

+ 11 - 0
README.md

@@ -0,0 +1,11 @@
+# EBPM
+
+## Setup
+
+```bash
+conda create -n ebpm python=3.10 -y
+conda activate ebpm
+# conda install cudatoolkit -y
+pip install jupyter
+pip install -e .
+```

File diff suppressed because it is too large
+ 133 - 0
sample_experiments/unsupervised_extractin.ipynb


+ 1 - 0
setup.cfg

@@ -26,6 +26,7 @@ install_requires =
     numpy>=1.23,
     pandas>=2.2
     scipy>=1.12
+    numba>=0.59
     stumpy>=1.12
     matplotlib>=3.5,<4
 

+ 9 - 0
src/ebpm/__init__.py

@@ -0,0 +1,9 @@
+__all__ = ["plot", "unsupervised", "match"]
+
+# set env variable to suppress the gpu usage!
+# it is slower but it works on all systemsi
+# NUMBA_DISABLE_CUDA=1
+import os
+os.environ['NUMBA_DISABLE_CUDA'] = '1'
+
+from ebpm import plot, unsupervised, match

+ 84 - 0
src/ebpm/match.py

@@ -0,0 +1,84 @@
+__all__ = ["find_prototype", "match_found_intervals"]
+
+import numpy as np
+import stumpy
+
+def find_prototype(ear_ts: np.ndarray, prototype: np.ndarray, th=3.0):
+    """
+    Find occurrences of a prototype pattern within a time series.
+
+    Parameters:
+    ear_ts (np.ndarray): The time series to search for occurrences of the prototype pattern.
+    prototype (np.ndarray): The prototype pattern to search for within the time series.
+    th (float, optional): The threshold value used to determine matches. Defaults to 3.0.
+
+    Returns:
+    list: A list of intervals where the prototype pattern is found in the time series.
+          Each interval is represented as [from, to, distance_to_prototype].
+    """
+    def threshold(D):
+        return np.nanmax([np.nanmean(D) - th * np.std(D), np.nanmin(D)])
+    
+    matches = stumpy.match(prototype, ear_ts, max_distance=threshold)
+    # sort the matches by index to get the original order
+    matches = sorted(matches, key=lambda x: x[1])
+    
+    intervals = []
+    # transform the matches to be [from, to, distance_to_prototype]
+    for match in matches:
+        intervals.append([match[1], match[1] + len(prototype), match[0]])
+    
+    return intervals
+
+def describe(matches: np.ndarray):
+    """
+    Prints information about the given matches array.
+    
+    Parameters:
+        matches (np.ndarray): An array of matches.
+    
+    Returns:
+        None
+    """
+    print(f"Contains {len(matches)} matches")
+    print(f"Matches: {matches}")
+
+def index_matching(
+    matches_l: np.ndarray,
+    matches_r: np.ndarray,
+    max_match_distance: int=50,
+) -> tuple[list[np.ndarray], list[np.ndarray]]:
+    """
+    Perform index matching between two arrays of matches.
+
+    Args:
+        matches_l (np.ndarray): Array of matches for the left side.
+        matches_r (np.ndarray): Array of matches for the right side.
+        max_match_distance (int, optional): Maximum distance allowed between matches. Defaults to 50.
+
+    Returns:
+        tuple[list[np.ndarray], list[np.ndarray]]: A tuple containing two lists of matches, 
+        where the first list corresponds to the matches from the left side and the second list 
+        corresponds to the matches from the right side.
+
+    """
+    start_idx_l = np.array(matches_l)[:, 0]
+    start_idx_r = np.array(matches_r)[:, 0]
+    
+    pairs = []
+    
+    for i, idx_l in enumerate(start_idx_l):
+        dists = np.abs(start_idx_r - idx_l)
+        
+        min_dist = np.min(dists)
+        min_argl = np.argmin(dists)
+        
+        if min_dist < max_match_distance:
+            # when there are two equal-dist matches, always keep the first one
+            pairs.append([i, min_argl])
+            
+    n_matches_l = [matches_l[i] for i, _ in pairs]
+    n_matches_r = [matches_r[j] for _, j in pairs]
+    
+    # TODO return also the non-matches        
+    return n_matches_l, n_matches_r

+ 127 - 0
src/ebpm/plot.py

@@ -0,0 +1,127 @@
+__all__ = ["ear_time_series", "candidates_overview", "matches"]
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib import figure
+
+
+def ear_time_series(
+    ear_r: np.ndarray,
+    ear_l: np.ndarray, 
+    xmin: int | None = None, 
+    xmax: int | None = None,
+) -> tuple[figure.Figure, np.ndarray]:
+    """
+    Plot the EAR (Eye Aspect Ratio) time series for the right and left eyes.
+    This is a helper function to visualize the EAR values over time.
+
+    Parameters:
+    - ear_r (np.ndarray): Array containing the EAR values for the right eye.
+    - ear_l (np.ndarray): Array containing the EAR values for the left eye.
+    - xmin (int | None): Minimum x-axis value. If None, the minimum value will be determined automatically.
+    - xmax (int | None): Maximum x-axis value. If None, the maximum value will be determined automatically.
+
+    Returns:
+    - fig (figure.Figure): The matplotlib Figure object containing the plotted time series.
+    """
+    # input validation
+    if not isinstance(ear_r, np.ndarray):
+        raise TypeError("ear_r must be a numpy array")
+    if not isinstance(ear_l, np.ndarray):
+        raise TypeError("ear_l must be a numpy array")
+    if ear_r.ndim != 1:
+        raise ValueError("ear_r must be a 1D array")
+    if ear_l.ndim != 1:
+        raise ValueError("ear_l must be a 1D array")
+    
+    if xmin is not None and not isinstance(xmin, int):
+        raise TypeError("xmin must be an integer")
+    if xmax is not None and not isinstance(xmax, int):
+        raise TypeError("xmax must be an integer")
+    
+    if xmin is None:
+        xmin = 0
+
+    if xmax is None:
+        xmax = max(len(ear_r), len(ear_l))
+    elif xmax > max(len(ear_r), len(ear_l)):
+        xmax = max(len(ear_r), len(ear_l))
+
+    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(20, 6), sharex=True, sharey=True, gridspec_kw={'hspace': 0.1})
+    
+    axs[0].plot(ear_r, c='red',  label='Right Eye [EAR Value]')
+    axs[1].plot(ear_l, c='blue', label='Left Eye [EAR Value]')
+
+    axs[0].minorticks_on()
+    axs[0].set_xlim([xmin, xmax])
+    axs[1].set_xlabel('Frame [#]', fontsize="18")
+    
+    axs[0].set_ylabel('Right Eye [EAR Value]', fontsize="12")
+    axs[1].set_ylabel('Left Eye [EAR Value]',  fontsize="12")
+    
+    fig.suptitle("Eye Aspect Ratio [EAR] Time Series", fontsize="20")
+    return fig, axs
+
+    
+def candidates_overview(candidates: list[np.ndarray] | np.ndarray) -> figure.Figure:
+    """
+    Generate a figure showing the overview of candidates' EAR time series.
+    
+    Parameters:
+        candidates (list[np.ndarray] | np.ndarray): A list of numpy arrays or a single numpy array representing the EAR time series of candidates.
+        
+    Returns:
+        figure.Figure: The generated matplotlib figure object.
+        
+    Raises:
+        TypeError: If candidates is not a list of numpy arrays.
+    """
+    if isinstance(candidates, np.ndarray):
+        candidates = [candidates]
+    if not isinstance(candidates, list):
+        raise TypeError("candidates must be a list of numpy arrays")
+    
+    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 7))
+    
+    for idx, candidate in enumerate(candidates):
+        ax.plot(candidate, label=f"Candidate {idx+1}")
+        
+    ax.set_xlabel('Frame [#]', fontsize="18")
+    ax.set_ylabel('EAR Value', fontsize="18")
+    
+    ax.set_ylim([0, 0.5])
+    
+    ax.legend()
+    fig.suptitle(f"Top {len(candidates)} candidates from EAR time series", fontsize="20")
+    return fig
+
+    
+def matches(
+    ear_l: np.ndarray,
+    ear_r: np.ndarray,
+    matches_l: list[np.ndarray],
+    matches_r: list[np.ndarray],
+    xmin: int | None = None,
+    xmax: int | None = None,
+) -> figure.Figure:
+    """
+    Plot the ear time series with highlighted matches.
+
+    Args:
+        ear_l (np.ndarray): Left ear time series.
+        ear_r (np.ndarray): Right ear time series.
+        matches_l (list[np.ndarray]): List of matches for the left ear.
+        matches_r (list[np.ndarray]): List of matches for the right ear.
+        xmin (int | None, optional): Minimum x-axis value. Defaults to None.
+        xmax (int | None, optional): Maximum x-axis value. Defaults to None.
+
+    Returns:
+        figure.Figure: The matplotlib figure object.
+    """
+    fig, axs = ear_time_series(ear_r, ear_l, xmin, xmax)
+    
+    for match in matches_l:
+        axs[0].axvspan(match[0], match[1], color='green', alpha=0.3)
+    for match in matches_r:
+        axs[1].axvspan(match[0], match[1], color='green', alpha=0.3)
+    return fig

+ 34 - 0
src/ebpm/unsupervised.py

@@ -0,0 +1,34 @@
+__all__ = ["extract_candidates"]
+
+import numpy as np
+import stumpy
+
+def extract_candidates(
+    ear_ts: np.ndarray, 
+    window_length:int=100, 
+    max_matches:int=10
+) -> list[np.ndarray]:
+    """
+    Extracts candidate motifs from the given time series using the Matrix Profile algorithm.
+
+    Parameters:
+    - ear_ts (np.ndarray): The input time series.
+    - window_length (int): The length of the sliding window used for candidates extraction. Default is 100. Should be based on the FPS of the video.
+    - max_matches (int): The maximum number of candidates to extract. Default is 10.
+
+    Returns:
+    - candidates (list[np.ndarray]): A list of candidate candidates extracted from the time series.
+    """
+    # input validation
+    if not isinstance(ear_ts, np.ndarray):
+        raise TypeError("ear_ts must be a numpy array")
+    if ear_ts.ndim != 1:
+        raise ValueError("ear_ts must be a 1D array")
+    if not isinstance(window_length, int):
+        raise TypeError("window_length must be an integer")
+    if not isinstance(max_matches, int):
+        raise TypeError("max_matches must be an integer")
+    
+    mp = stumpy.stump(ear_ts, window_length) # matrix profile
+    _, candidates_idx = stumpy.motifs(ear_ts, mp[:, 0], max_matches=max_matches)
+    return [ear_ts[idx:idx+window_length] for idx in (candidates_idx[0] + window_length // 2)]

Some files were not shown because too many files changed in this diff