phillip.rothenbeck 1 年間 前
コミット
ca5b462568
1 ファイル変更152 行追加0 行削除
  1. 152 0
      src/plotter.py

+ 152 - 0
src/plotter.py

@@ -0,0 +1,152 @@
+import os
+import torch
+import imageio
+import matplotlib.pyplot as plt
+
+from matplotlib import rcParams
+
+FRAME_DIR = 'visualizations/temp/'
+VISUALISATION_DIR = 'visualizations/'
+
+SUSCEPTIBLE_COLOR = '#6399f7'
+INFECTIOUS_COLOR = '#f56262'
+REMOVED_COLOR = '#83eb5e'
+
+class Plotter:
+    def __init__(self, additional_colors=[], font_size=12, font='Comfortaa', font_color='#595959') -> None:
+        """Plotter of scientific plots and animations, for dinn.py.
+
+        Args:
+            additional_colors (list, optional): List of strings that describe additional colors for plotting (used in order). Defaults to [].
+            font_size (int, optional): Size of the fonts in the plots. Defaults to 12.
+            font (str, optional): Font family used in the plots. Defaults to 'Comfortaa'.
+            font_color (str, optional): Color of the fonts used in the plots. Defaults to '#595959'.
+        """
+        self.__colors = [SUSCEPTIBLE_COLOR, INFECTIOUS_COLOR, REMOVED_COLOR] + additional_colors
+
+        rcParams['font.family'] = font
+        rcParams['font.size'] = font_size
+
+        rcParams['text.color'] = font_color
+        rcParams['axes.labelcolor'] = font_color
+        rcParams['xtick.color'] = font_color
+        rcParams['ytick.color'] = font_color
+
+        self.__frames = []
+
+    def __generate_figure(self, shape=(4, 4)):
+        """Generate figure for a plot
+
+        Args:
+            shape (tuple, optional): Size of the plot in dimensions. Defaults to (4, 4).
+
+        Returns:
+            Figure: plt.Figure that was generated.
+        """
+        fig = plt.figure(figsize=shape)
+        return fig
+
+    def reset_animation(self):
+        """Delete all frames that were saved.
+        """
+        self.__frames = []
+    
+    def plot(self, 
+             x, 
+             y:list, 
+             labels:list,
+             file_name:str,
+             title:str, 
+             figure_shape:tuple, 
+             is_frame=False, 
+             is_background=[], 
+             plot_legend=True, 
+             y_log_scale=False, 
+             lw=3, 
+             legend_loc='best', 
+             ylim=(None, None),
+             xlabel='',
+             ylabel=''):
+        """Plotting method.
+
+        Args:
+            x (ndarray): Array of the definition values for the plots (x-values).
+            y (list): List of torch.Tensors or numpy.arrays that are supposed to be plotted in the graph.
+            labels (list): List of the string labels for each plot in the graph. Needs to have the same length as y.
+            file_name (str): Name for the file that will be saved containing the plots.
+            title (str): Title of the plot.
+            figure_shape (tuple): Tuple of integers. Shape of the figure.
+            is_frame (bool, optional): Decides if plot will be used as frame of a animation. Defaults to False.
+            is_background (list, optional): List of strings that show if the corresponding element is supposed to be highlighted('f') or not('b'). Either needs to be empty(no highlighting) or the same length as y. Defaults to [].
+            plot_legend (bool, optional): Decides whether a legend will be plotted. Defaults to True.
+            y_log_scale (bool, optional): Decides whether the y axis is scaled in log. Defaults to False.
+            lw (int, optional): Size of the lines plotted. Defaults to 3.
+            legend_loc (str, optional): Position for the legend. Will be passed to plt.legend(loc=legend_loc). Is not used, when plot_legend=False. Defaults to 'best'.
+            ylim (tuple, optional): Tuple that holds the limits for the y axis. Will be passed to plt.ylim(ylim). Default does not set any limits.. Defaults to (None, None).
+            xlabel (str, optional): Label for the x axis. Defaults to ''.
+            ylabel (str, optional): Label for the y axis. Defaults to ''.
+        """
+        assert len(y) == len(labels), "There must be the same amount of labels as there are plots."
+        assert len(is_background) == 0 or len(y) == len(is_background), "If given the back_foreground list must have the same length as labels has."
+        fig = self.__generate_figure(shape=figure_shape)
+
+        ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
+        ax.set_facecolor('xkcd:white')
+        ax.yaxis.set_tick_params(length=0, which='both')
+        ax.xaxis.set_tick_params(length=0, which='both')
+        ax.grid(which='major', c='black', lw=0.2, ls='-')
+        ax.set_title(title)
+
+        for spine in ('top', 'right', 'bottom', 'left'):
+            ax.spines[spine].set_visible(False)
+
+        for i, array in enumerate(y):
+            alpha = 1
+            if len(is_background) != 0:
+                if is_background[i]:
+                    alpha = 0.25
+            
+            if torch.is_tensor(array):
+                data = array.cpu().detach().numpy()
+            else:
+                data = array
+            
+            if len(is_background) != 0 and alpha == 1:
+                ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle='dashed', c=self.__colors[i % len(self.__colors)])
+            else:
+                ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, c=self.__colors[i % len(self.__colors)])
+
+        if plot_legend:
+            plt.legend(loc=legend_loc)
+
+        if ylim[0] != None:
+            plt.ylim(*ylim)
+
+        if y_log_scale:
+            plt.yscale('log')
+
+        if xlabel != '':
+            plt.xlabel(xlabel)
+
+        if ylabel != '':
+            plt.ylabel(ylabel)
+ 
+        if not os.path.exists(FRAME_DIR):
+            os.makedirs(FRAME_DIR)
+
+        if is_frame:
+            frame_path = FRAME_DIR + 'frame.png'
+            plt.savefig(frame_path)
+            plt.close(fig)
+            self.__frames.append(imageio.imread(frame_path))
+            os.remove(frame_path)
+        else:
+            plt.savefig(VISUALISATION_DIR + f'{file_name}.png')
+
+    def animate(self, name: str):
+        """Builds animation from images saved in self.frames. Then saves animation as gif.
+
+        Args:
+            name (str): Name of the gif file.
+        """
+        imageio.mimsave(VISUALISATION_DIR + f'{name}.gif', self.__frames, duration=0.5)