فهرست منبع

Add manual prototype

Tim Büchner 1 سال پیش
والد
کامیت
38c49b2fd9
5فایلهای تغییر یافته به همراه808 افزوده شده و 2 حذف شده
  1. 133 0
      sample_experiments/manual_extraction.ipynb
  2. 2 2
      src/ebpm/__init__.py
  3. 534 0
      src/ebpm/curlyBrace.py
  4. 79 0
      src/ebpm/manual.py
  5. 60 0
      src/ebpm/plot.py

تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 133 - 0
sample_experiments/manual_extraction.ipynb


+ 2 - 2
src/ebpm/__init__.py

@@ -1,4 +1,4 @@
-__all__ = ["plot", "unsupervised", "match"]
+__all__ = ["plot", "unsupervised", "match", "manual"]
 
 # set env variable to suppress the gpu usage!
 # it is slower but it works on all systemsi
@@ -6,4 +6,4 @@ __all__ = ["plot", "unsupervised", "match"]
 import os
 os.environ['NUMBA_DISABLE_CUDA'] = '1'
 
-from ebpm import plot, unsupervised, match
+from ebpm import plot, unsupervised, match, manual

+ 534 - 0
src/ebpm/curlyBrace.py

@@ -0,0 +1,534 @@
+# -*- coding: utf-8 -*- 
+
+'''
+Module Name : curlyBrace
+
+Author : 高斯羽 博士 (Dr. GAO, Siyu)
+
+Version : 1.0.2
+
+Last Modified : 2019-04-22
+
+This module is basically an Python implementation of the function written Pål Næverlid Sævik
+for MATLAB (link in Reference).
+
+The function "curlyBrace" allows you to plot an optionally annotated curly bracket between 
+two points when using matplotlib.
+
+The usual settings for line and fonts in matplotlib also apply.
+
+The function takes the axes scales into account automatically. But when the axes aspect is 
+set to "equal", the auto switch should be turned off.
+
+Change Log
+----------------------
+* **Notable changes:**
+    + Version : 1.0.2
+        - Added considerations for different scaled axes and log scale
+    + Version : 1.0.1
+        - First version.
+
+Reference
+----------------------
+https://uk.mathworks.com/matlabcentral/fileexchange/38716-curly-brace-annotation
+
+List of functions
+----------------------
+
+* getAxSize_
+* curlyBrace_
+
+'''
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+def getAxSize(fig, ax):
+    '''
+    .. _getAxSize :
+
+    Get the axes size in pixels.
+
+    Parameters
+    ----------
+    fig : matplotlib figure object
+        The of the target axes.
+
+    ax : matplotlib axes object
+        The target axes.
+
+    Returns
+    -------
+    ax_width : float
+        The axes width in pixels.
+
+    ax_height : float
+        The axes height in pixels.
+
+    Reference
+    -----------
+    https://stackoverflow.com/questions/19306510/determine-matplotlib-axis-size-in-pixels
+    '''
+
+    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
+    ax_width, ax_height = bbox.width, bbox.height
+    ax_width *= fig.dpi
+    ax_height *= fig.dpi
+
+    return ax_width, ax_height
+
+def curlyBrace(fig, ax, p1, p2, k_r=0.1, bool_auto=True, str_text='', int_line_num=2, fontdict={}, **kwargs):
+# def curlyBrace(fig, ax, p1, p2, k_r=0.1, bool_auto=True, str_text='', int_line_num=2, fontdict={}, **kwargs):
+    '''
+    .. _curlyBrace :
+
+    Plot an optionally annotated curly bracket on the given axes of the given figure.
+
+    Note that the brackets are anti-clockwise by default. To reverse the text position, swap
+    "p1" and "p2".
+
+    Note that, when the axes aspect is not set to "equal", the axes coordinates need to be
+    transformed to screen coordinates, otherwise the arcs may not be seeable. 
+
+    Parameters
+    ----------
+    fig : matplotlib figure object
+        The of the target axes.
+
+    ax : matplotlib axes object
+        The target axes.
+
+    p1 : two element numeric list
+        The coordinates of the starting point.
+
+    p2 : two element numeric list
+        The coordinates of the end point.
+
+    k_r : float
+        This is the gain controlling how "curvy" and "pointy" (height) the bracket is.
+
+        Note that, if this gain is too big, the bracket would be very strange.
+
+    bool_auto : boolean
+        This is a switch controlling wether to use the auto calculation of axes
+        scales.
+
+        When the two axes do not have the same aspects, i.e., not "equal" scales,
+        this should be turned on, i.e., True.
+
+        When "equal" aspect is used, this should be turned off, i.e., False.
+
+        If you do not set this to False when setting the axes aspect to "equal",
+        the bracket will be in funny shape.
+
+        Default = True
+
+    str_text : string
+        The annotation text of the bracket. It would displayed at the mid point
+        of bracket with the same rotation as the bracket.
+
+        By default, it follows the anti-clockwise convention. To flip it, swap 
+        the end point and the starting point.
+
+        The appearance of this string can be set by using "fontdict", which follows
+        the same syntax as the normal matplotlib syntax for font dictionary.
+
+        Default = empty string (no annotation)
+
+    int_line_num : int
+        This argument determines how many lines the string annotation is from the summit
+        of the bracket.
+
+        The distance would be affected by the font size, since it basically just a number of
+        lines appended to the given string.
+
+        Default = 2
+
+    fontdict : dictionary
+        This is font dictionary setting the string annotation. It is the same as normal
+        matplotlib font dictionary.
+
+        Default = empty dict
+
+    **kwargs : matplotlib line setting arguments
+        This allows the user to set the line arguments using named arguments that are
+        the same as in matplotlib.
+
+    Returns
+    -------
+    theta : float
+        The bracket angle in radians.
+
+    summit : list
+        The positions of the bracket summit.
+
+    arc1 : list of lists
+        arc1 positions.
+
+    arc2 : list of lists
+        arc2 positions.
+
+    arc3 : list of lists
+        arc3 positions.
+
+    arc4 : list of lists
+        arc4 positions.
+
+    Reference
+    ----------
+    https://uk.mathworks.com/matlabcentral/fileexchange/38716-curly-brace-annotation
+    '''
+
+    pt1 = [None, None]
+    pt2 = [None, None]
+
+    ax_width, ax_height = getAxSize(fig, ax)
+
+    ax_xlim = list(ax.get_xlim())
+    ax_ylim = list(ax.get_ylim())
+
+    # log scale consideration
+    if 'log' in ax.get_xaxis().get_scale():
+
+        if p1[0] > 0.0:
+
+            pt1[0] = np.log(p1[0])
+
+        elif p1[0] < 0.0:
+
+            pt1[0] = -np.log(abs(p1[0]))
+
+        else:
+
+            pt1[0] = 0.0
+
+        if p2[0] > 0.0:
+
+            pt2[0] = np.log(p2[0])
+
+        elif p2[0] < 0.0:
+
+            pt2[0] = -np.log(abs(p2[0]))
+
+        else:
+
+            pt2[0] = 0
+
+        for i in range(0, len(ax_xlim)):
+
+            if ax_xlim[i] > 0.0:
+
+                ax_xlim[i] = np.log(ax_xlim[i])
+
+            elif ax_xlim[i] < 0.0:
+
+                ax_xlim[i] = -np.log(abs(ax_xlim[i]))
+
+            else:
+
+                ax_xlim[i] = 0.0
+
+    else:
+
+        pt1[0] = p1[0]
+        pt2[0] = p2[0]
+
+    if 'log' in ax.get_yaxis().get_scale():
+
+        if p1[1] > 0.0:
+
+            pt1[1] = np.log(p1[1])
+
+        elif p1[1] < 0.0:
+
+            pt1[1] = -np.log(abs(p1[1]))
+
+        else:
+
+            pt1[1] = 0.0
+
+        if p2[1] > 0.0:
+
+            pt2[1] = np.log(p2[1])
+
+        elif p2[1] < 0.0:
+
+            pt2[1] = -np.log(abs(p2[1]))
+
+        else:
+
+            pt2[1] = 0.0
+
+        for i in range(0, len(ax_ylim)):
+
+            if ax_ylim[i] > 0.0:
+
+                ax_ylim[i] = np.log(ax_ylim[i])
+
+            elif ax_ylim[i] < 0.0:
+
+                ax_ylim[i] = -np.log(abs(ax_ylim[i]))
+
+            else:
+
+                ax_ylim[i] = 0.0
+
+    else:
+
+        pt1[1] = p1[1]
+        pt2[1] = p2[1]
+
+    # get the ratio of pixels/length
+    xscale = ax_width / abs(ax_xlim[1] - ax_xlim[0])
+    yscale = ax_height / abs(ax_ylim[1] - ax_ylim[0])
+
+    # this is to deal with 'equal' axes aspects
+    if bool_auto:
+
+        pass
+
+    else:
+
+        xscale = 1.0
+        yscale = 1.0
+
+    # convert length to pixels, 
+    # need to minus the lower limit to move the points back to the origin. Then add the limits back on end.
+    pt1[0] = (pt1[0] - ax_xlim[0]) * xscale
+    pt1[1] = (pt1[1] - ax_ylim[0]) * yscale
+    pt2[0] = (pt2[0] - ax_xlim[0]) * xscale
+    pt2[1] = (pt2[1] - ax_ylim[0]) * yscale
+
+    # calculate the angle
+    theta = np.arctan2(pt2[1] - pt1[1], pt2[0] - pt1[0])
+
+    # calculate the radius of the arcs
+    r = np.hypot(pt2[0] - pt1[0], pt2[1] - pt1[1]) * k_r
+
+    # arc1 centre
+    x11 = pt1[0] + r * np.cos(theta)
+    y11 = pt1[1] + r * np.sin(theta)
+
+    # arc2 centre
+    x22 = (pt2[0] + pt1[0]) / 2.0 - 2.0 * r * np.sin(theta) - r * np.cos(theta)
+    y22 = (pt2[1] + pt1[1]) / 2.0 + 2.0 * r * np.cos(theta) - r * np.sin(theta)
+
+    # arc3 centre
+    x33 = (pt2[0] + pt1[0]) / 2.0 - 2.0 * r * np.sin(theta) + r * np.cos(theta)
+    y33 = (pt2[1] + pt1[1]) / 2.0 + 2.0 * r * np.cos(theta) + r * np.sin(theta)
+
+    # arc4 centre
+    x44 = pt2[0] - r * np.cos(theta)
+    y44 = pt2[1] - r * np.sin(theta)
+
+    # prepare the rotated
+    q = np.linspace(theta, theta + np.pi/2.0, 50)
+
+    # reverse q
+    # t = np.flip(q) # this command is not supported by lower version of numpy
+    t = q[::-1]
+
+    # arc coordinates
+    arc1x = r * np.cos(t + np.pi/2.0) + x11
+    arc1y = r * np.sin(t + np.pi/2.0) + y11
+
+    arc2x = r * np.cos(q - np.pi/2.0) + x22
+    arc2y = r * np.sin(q - np.pi/2.0) + y22
+
+    arc3x = r * np.cos(q + np.pi) + x33
+    arc3y = r * np.sin(q + np.pi) + y33
+
+    arc4x = r * np.cos(t) + x44
+    arc4y = r * np.sin(t) + y44
+
+    # convert back to the axis coordinates
+    arc1x = arc1x / xscale + ax_xlim[0]
+    arc2x = arc2x / xscale + ax_xlim[0]
+    arc3x = arc3x / xscale + ax_xlim[0]
+    arc4x = arc4x / xscale + ax_xlim[0]
+
+    arc1y = arc1y / yscale + ax_ylim[0]
+    arc2y = arc2y / yscale + ax_ylim[0]
+    arc3y = arc3y / yscale + ax_ylim[0]
+    arc4y = arc4y / yscale + ax_ylim[0]
+
+    # log scale consideration
+    if 'log' in ax.get_xaxis().get_scale():
+
+        for i in range(0, len(arc1x)):
+
+            if arc1x[i] > 0.0:
+
+                arc1x[i] = np.exp(arc1x[i])
+
+            elif arc1x[i] < 0.0:
+
+                arc1x[i] = -np.exp(abs(arc1x[i]))
+
+            else:
+
+                arc1x[i] = 0.0
+
+        for i in range(0, len(arc2x)):
+
+            if arc2x[i] > 0.0:
+
+                arc2x[i] = np.exp(arc2x[i])
+
+            elif arc2x[i] < 0.0:
+
+                arc2x[i] = -np.exp(abs(arc2x[i]))
+
+            else:
+
+                arc2x[i] = 0.0
+
+        for i in range(0, len(arc3x)):
+
+            if arc3x[i] > 0.0:
+
+                arc3x[i] = np.exp(arc3x[i])
+
+            elif arc3x[i] < 0.0:
+
+                arc3x[i] = -np.exp(abs(arc3x[i]))
+
+            else:
+
+                arc3x[i] = 0.0
+
+        for i in range(0, len(arc4x)):
+
+            if arc4x[i] > 0.0:
+
+                arc4x[i] = np.exp(arc4x[i])
+
+            elif arc4x[i] < 0.0:
+
+                arc4x[i] = -np.exp(abs(arc4x[i]))
+
+            else:
+
+                arc4x[i] = 0.0
+
+    else:
+
+        pass
+
+    if 'log' in ax.get_yaxis().get_scale():
+
+        for i in range(0, len(arc1y)):
+
+            if arc1y[i] > 0.0:
+
+                arc1y[i] = np.exp(arc1y[i])
+
+            elif arc1y[i] < 0.0:
+
+                arc1y[i] = -np.exp(abs(arc1y[i]))
+
+            else:
+
+                arc1y[i] = 0.0
+
+        for i in range(0, len(arc2y)):
+
+            if arc2y[i] > 0.0:
+
+                arc2y[i] = np.exp(arc2y[i])
+
+            elif arc2y[i] < 0.0:
+
+                arc2y[i] = -np.exp(abs(arc2y[i]))
+
+            else:
+
+                arc2y[i] = 0.0
+
+        for i in range(0, len(arc3y)):
+
+            if arc3y[i] > 0.0:
+
+                arc3y[i] = np.exp(arc3y[i])
+
+            elif arc3y[i] < 0.0:
+
+                arc3y[i] = -np.exp(abs(arc3y[i]))
+
+            else:
+
+                arc3y[i] = 0.0
+
+        for i in range(0, len(arc4y)):
+
+            if arc4y[i] > 0.0:
+
+                arc4y[i] = np.exp(arc4y[i])
+
+            elif arc4y[i] < 0.0:
+
+                arc4y[i] = -np.exp(abs(arc4y[i]))
+
+            else:
+
+                arc4y[i] = 0.0
+
+    else:
+
+        pass
+
+    # plot arcs
+    ax.plot(arc1x, arc1y, **kwargs)
+    ax.plot(arc2x, arc2y, **kwargs)
+    ax.plot(arc3x, arc3y, **kwargs)
+    ax.plot(arc4x, arc4y, **kwargs)
+
+    # plot lines
+    ax.plot([arc1x[-1], arc2x[1]], [arc1y[-1], arc2y[1]], **kwargs)
+    ax.plot([arc3x[-1], arc4x[1]], [arc3y[-1], arc4y[1]], **kwargs)
+
+    summit = [arc2x[-1], arc2y[-1]]
+
+    if str_text:
+
+        int_line_num = int(int_line_num)
+
+        str_temp = '\n' * int_line_num
+        
+        # convert radians to degree and within 0 to 360
+        ang = np.degrees(theta) % 360.0
+
+        if (ang >= 0.0) and (ang <= 90.0):
+
+            rotation = ang
+
+            str_text = str_text + str_temp
+
+        if (ang > 90.0) and (ang < 270.0):
+
+            rotation = ang + 180.0
+
+            str_text = str_temp + str_text
+
+        elif (ang >= 270.0) and (ang <= 360.0):
+
+            rotation = ang
+
+            str_text = str_text + str_temp
+
+        else:
+
+            rotation = ang
+
+        ax.axes.text(arc2x[-1], arc2y[-1], str_text, ha='center', va='center', rotation=rotation, fontdict=fontdict)
+
+    else:
+
+        pass
+
+    arc1 = [arc1x, arc1y]
+    arc2 = [arc2x, arc2y]
+    arc3 = [arc3x, arc3y]
+    arc4 = [arc4x, arc4y]
+
+    return theta, summit, arc1, arc2, arc3, arc4

+ 79 - 0
src/ebpm/manual.py

@@ -0,0 +1,79 @@
+__all__ = ["define_prototype"]
+
+import numpy as np
+from scipy.signal import gaussian
+
+def define_prototype(
+    sig1: float = 6.3,
+    sig2: float = 13.6,
+    baseline: float = 0.3,
+    prominance: float = 0.25,
+    apex_location:float | None = None,
+    window_size: int=100,
+    noise: bool=False,
+    return_params: bool = False
+) -> np.ndarray | tuple[np.ndarray, tuple[float, float, float, float, int]]:
+    """
+    Define a manual prototype composed of two Gaussians.
+    
+    The default parameters are set to the values used in the paper.
+    They are learned from the data and are used to define the prototype.
+    
+    We compute the onset and offset as 3 * sig1 and 3 * sig2, respectively.
+    If these values are not inside the window, the definition will raise an error.
+    
+    Parameters:
+        sig1 (float): Standard deviation of the first Gaussian.
+        sig2 (float): Standard deviation of the second Gaussian.
+        baseline (float): Baseline value of the prototype.
+        prominance (float): Prominence of the Gaussians.
+        apex_location (int | None): Location of the apex of the prototype. If None, it is set to 40% of the window size.
+        window_size (int): Size of the window.
+        noise (bool): Flag indicating whether to add noise to the prototype.
+        return_params (bool): Flag indicating whether to return the prototype parameters.
+    
+    Returns:
+        numpy.ndarray: The manual prototype.
+    """
+    
+    if apex_location is None:
+        # default 40% of the window size
+        # this how the prototype is defined in the paper
+        apex_location = 0.4
+    apex_location = int(window_size * apex_location)
+    print(apex_location, type(apex_location))
+
+    onset_x  = apex_location - 3 * sig1
+    offset_x = apex_location + 3 * sig2
+    
+    if onset_x < 0 or offset_x > window_size:
+        raise ValueError("Onset and offset are not inside the window. Choose different parameters.")
+
+    # TODO replace with scipy.stats.norm.pdf 
+    y1 = -prominance * gaussian(window_size*2, std=sig1) + baseline
+    y2 = -prominance * gaussian(window_size*2, std=sig2) + baseline
+    y = np.append(y1[:window_size], y2[window_size:])
+    
+    print(y1.shape, y2.shape)
+    
+    if noise:
+        y1 = y1 + _noise_fct(0.05, window_size)
+        y2 = y2 + _noise_fct(0.05, window_size) 
+    
+    if noise:
+        y = y + _noise_fct(0.05, window_size)
+    
+    y = y[window_size - apex_location: 2*window_size - apex_location]
+    if return_params:
+        return y, (sig1, sig2, baseline, prominance, apex_location)
+    
+    return y
+
+def _noise_fct(
+    noise_std: float, 
+    window_size=10
+) -> np.ndarray:
+    "Random noise based on learned data."
+    # create custom random generator to not interfere with other random calls!
+    rand_generator = np.random.RandomState(0)
+    return (rand_generator.random(2*window_size) * 2 - 1) * noise_std

+ 60 - 0
src/ebpm/plot.py

@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 from matplotlib import figure
 
+from .curlyBrace import curlyBrace
 
 def ear_time_series(
     ear_r: np.ndarray,
@@ -124,4 +125,63 @@ def matches(
         axs[0].axvspan(match[0], match[1], color='green', alpha=0.3)
     for match in matches_r:
         axs[1].axvspan(match[0], match[1], color='green', alpha=0.3)
+    return fig
+
+    
+def manual_prototype(
+    prototype: np.ndarray, 
+    xmin: int | None = None, 
+    xmax: int | None = None,
+    params: tuple[float, float, float, float, int] | None = None,
+) -> figure.Figure:
+    """
+    Plot the given prototype pattern.
+
+    Args:
+        prototype (np.ndarray): The prototype pattern to plot.
+        xmin (int | None, optional): Minimum x-axis value. Defaults to None.
+        xmax (int | None, optional): Maximum x-axis value. Defaults to None.
+
+    Returns:
+        figure.Figure: The matplotlib figure object.
+    """
+    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 7))
+    ax.plot(prototype)
+    ax.set_xlabel('Frame [#]', fontsize="18")
+    ax.set_ylabel('EAR Value', fontsize="18")
+    ax.set_ylim([0, 0.5])
+    ax.set_xlim([xmin, xmax])
+    fig.suptitle("Manual Prototype Pattern", fontsize="20")
+    
+    if params is None:
+        return fig
+    
+    sig1, sig2, bs, prominance, apex_location = params
+
+    # draw a point for onset and offset
+    onset_x  = apex_location - int(sig1 * 3)
+    offset_x = apex_location + int(sig2 * 3)
+    apex_x   = apex_location
+    apex_y   = prototype[apex_x] 
+    
+    ax.plot(onset_x,  prototype[onset_x],  'ro')
+    ax.plot(offset_x, prototype[offset_x], 'ro')
+    ax.plot(apex_x,   prototype[apex_x],   'ro')
+    
+    # write onset, apex, and offset
+    # slight lower left of the point
+    ax.text(onset_x-5,  prototype[onset_x]-0.02,  'Onset', fontsize=12, color='r')
+    ax.text(offset_x, prototype[offset_x]-0.02, 'Offset', fontsize=12, color='r')
+    ax.text(apex_x,   apex_y-0.02,   'Apex', fontsize=12, color='r')
+    
+    # write the text prominance to the vertical line
+    ax.text(apex_x-3, apex_y+prominance/3, 'Prominance', fontsize=12, color='r', rotation=90)
+    
+    # draw lines to describe the prototype
+    ax.vlines(x=apex_x, ymin=prototype[apex_location], ymax=bs, color='r', linestyle='--')
+    ax.hlines(y=bs, xmin=onset_x, xmax=offset_x, color='r', linestyle='--')
+    
+    # draw curly braces
+    curlyBrace(fig, ax, (onset_x, bs), (apex_x, bs), 0.03,  bool_auto=True, c="r",  str_text="3 * σ1")
+    curlyBrace(fig, ax, (apex_x, bs), (offset_x, bs),0.03, bool_auto=True, c="r", str_text="3 * σ2")
     return fig

برخی فایل ها در این مقایسه diff نمایش داده نمی شوند زیرا تعداد فایل ها بسیار زیاد است