소스 검색

add paper layout + scatter function

phillip.rothenbeck 1 년 전
부모
커밋
0a7b829650
1개의 변경된 파일224개의 추가작업 그리고 17개의 파일을 삭제
  1. 224 17
      src/plotter.py

+ 224 - 17
src/plotter.py

@@ -1,9 +1,12 @@
 import os
 import torch
 import imageio
+import numpy as np
 import matplotlib.pyplot as plt
+import matplotlib.ticker as ticker 
 
 from matplotlib import rcParams
+from itertools import cycle
 
 FRAME_DIR = 'visualizations/temp/'
 VISUALISATION_DIR = 'visualizations/'
@@ -13,7 +16,7 @@ INFECTIOUS_COLOR = '#f56262'
 REMOVED_COLOR = '#83eb5e'
 
 class Plotter:
-    def __init__(self, additional_colors=[], font_size=12, font='Comfortaa', font_color='#595959') -> None:
+    def __init__(self, additional_colors=[], font_size=20, font='serif', font_color='#000000') -> None:
         """Plotter of scientific plots and animations, for dinn.py.
 
         Args:
@@ -24,14 +27,26 @@ class Plotter:
         """
         self.__colors = [SUSCEPTIBLE_COLOR, INFECTIOUS_COLOR, REMOVED_COLOR] + additional_colors
 
+        self.__lines_styles = ['solid', 'dotted', 'dashdot', (5, (10, 3)), (0, (5, 1)), (0, (3, 5, 1, 5)), (0, (3, 1, 1, 1)), (0, (3, 5, 1, 5, 1, 5)), (0, (3, 1, 1, 1, 1, 1))]
+
+        self.__marker_styles = ['o', '^', 's']
+
+        rcParams['text.usetex'] = True
+        rcParams['text.color'] = font_color
+
         rcParams['font.family'] = font
         rcParams['font.size'] = font_size
 
-        rcParams['text.color'] = font_color
+        rcParams['pgf.texsystem'] = 'pdflatex'
+        rcParams['pgf.rcfonts'] = False
+
         rcParams['axes.labelcolor'] = font_color
         rcParams['xtick.color'] = font_color
         rcParams['ytick.color'] = font_color
 
+        plt.rc("text", usetex=True)
+        plt.rc("text.latex", preamble=r"\usepackage{amssymb} \usepackage{wasysym}")
+
         self.__frames = []
 
     def __generate_figure(self, shape=(4, 4)):
@@ -43,7 +58,7 @@ class Plotter:
         Returns:
             Figure: plt.Figure that was generated.
         """
-        fig = plt.figure(figsize=shape)
+        fig = plt.figure(figsize=shape, constrained_layout=True)
         return fig
 
     def reset_animation(self):
@@ -51,6 +66,7 @@ class Plotter:
         """
         self.__frames = []
     
+    #TODO comments
     def plot(self, 
              x, 
              y:list, 
@@ -58,16 +74,19 @@ class Plotter:
              file_name:str,
              title:str, 
              figure_shape:tuple, 
+             event_lookup={},
              is_frame=False, 
              is_background=[], 
+             fill_between=[],
              plot_legend=True, 
              y_log_scale=False, 
              lw=3, 
              legend_loc='best', 
              ylim=(None, None),
+             number_xlabels = 5,
              xlabel='',
              ylabel='', 
-             xlabel_rotation=0):
+             xlabel_rotation=None):
         """Plotting method.
 
         Args:
@@ -87,20 +106,26 @@ class Plotter:
             xlabel (str, optional): Label for the x axis. Defaults to ''.
             ylabel (str, optional): Label for the y axis. Defaults to ''.
         """
-        assert len(y) == len(labels), "There must be the same amount of labels as there are plots."
+        assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
         assert len(is_background) == 0 or len(y) == len(is_background), "If given the back_foreground list must have the same length as labels has."
         fig = self.__generate_figure(shape=figure_shape)
 
         ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
         ax.set_facecolor('xkcd:white')
-        ax.yaxis.set_tick_params(length=0, which='both')
-        ax.xaxis.set_tick_params(length=0, which='both')
-        ax.grid(which='major', c='black', lw=0.2, ls='-')
+        #ax.yaxis.set_tick_params(length=0, which='both')
+        #ax.xaxis.set_tick_params(length=0, which='both')
+
+        #ax.grid(which='major', c='black', lw=0.2, ls='-')
         ax.set_title(title)
 
-        for spine in ('top', 'right', 'bottom', 'left'):
-            ax.spines[spine].set_visible(False)
+        #for spine in ('top', 'right', 'bottom', 'left'):
+         #   ax.spines[spine].set_visible(False)
+
+        if torch.is_tensor(x):
+            x = x.cpu().detach().numpy()
 
+        linecycler = cycle(self.__lines_styles)
+        j = 0
         for i, array in enumerate(y):
             alpha = 1
             if len(is_background) != 0:
@@ -111,12 +136,19 @@ class Plotter:
                 data = array.cpu().detach().numpy()
             else:
                 data = array
-            
-            if len(is_background) != 0 and alpha == 1:
-                ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle='dashed', c=self.__colors[i % len(self.__colors)])
-            else:
-                ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, c=self.__colors[i % len(self.__colors)])
 
+            space = int(len(x) / number_xlabels)
+            ax.xaxis.set_major_locator(ticker.MultipleLocator(space)) 
+        
+            ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle=next(linecycler), c=self.__colors[i % len(self.__colors)])
+            if i < len(fill_between):
+                ax.fill_between(x, data+fill_between[i], data-fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
+            j = i
+
+        for event in event_lookup.keys():
+            j += 1
+            plt.axvline(x = event_lookup[event], color = self.__colors[j % len(self.__colors)], label = event, ls=next(linecycler))
+            
         if plot_legend:
             plt.legend(loc=legend_loc)
 
@@ -132,7 +164,8 @@ class Plotter:
         if ylabel != '':
             plt.ylabel(ylabel)
 
-        plt.xticks(rotation=xlabel_rotation)
+        if xlabel_rotation != None:
+            plt.xticks(rotation=xlabel_rotation)
  
         if not os.path.exists(FRAME_DIR):
             os.makedirs(FRAME_DIR)
@@ -144,7 +177,181 @@ class Plotter:
             self.__frames.append(imageio.imread(frame_path))
             os.remove(frame_path)
         else:
-            plt.savefig(VISUALISATION_DIR + f'{file_name}.png')
+            plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
+
+        plt.close()
+
+    def cluster_plot(self, 
+                     x, 
+                     y:list, 
+                     labels, 
+                     shape, 
+                     plots_shape, 
+                     file_name:str, 
+                     titles:list, 
+                     number_xlabels=5,
+                     lw=3, 
+                     fill_between=[],
+                     event_lookup={},
+                     xlabel='',
+                     ylabel='',
+                     ylim=(None, None),
+                     y_lim_exception=None,
+                     y_log_scale=False, 
+                     legend_loc=(0.5,0.992),
+                     add_y_space=0.05,
+                     number_of_legend_columns=1,
+                     same_axes=True,
+                     free_axis=(None, None),
+                     plot_all_labels = True):
+        real_shape = (shape[1] * plots_shape[0], shape[0] * plots_shape[1])
+        fig, axes = plt.subplots(*shape, figsize=real_shape, sharex=same_axes, sharey=same_axes)
+        plot_idx = 0
+
+        if torch.is_tensor(x):
+            x = x.cpu().detach().numpy()
+
+        if 1 in shape:
+            for i in range(len(axes)):
+                linecycler = cycle(self.__lines_styles)
+                colorcycler = cycle(self.__colors)
+                for j, array in enumerate(y[i]):
+                    color = next(colorcycler)
+                    if torch.is_tensor(array):
+                        data = array.cpu().detach().numpy()
+                    else:
+                        data = array
+
+                    space = int(len(x) / number_xlabels)
+                    axes[i].xaxis.set_major_locator(ticker.MultipleLocator(space))
+                    axes[i].plot(x, 
+                                 data, 
+                                 linestyle=next(linecycler),
+                                 label=labels[j],
+                                 c=color,
+                                 lw=lw)
+                    axes[i].set_title(titles[i])
+                    if j < len(fill_between[i]):
+                        axes[i].fill_between(x, data+fill_between[i][j], data-fill_between[i][j], facecolor=color, alpha=0.5)
+                for event in event_lookup.keys():
+                    axes[i].axvline(x=event_lookup[event], 
+                                    color=next(colorcycler), 
+                                    label=event, 
+                                    ls=next(linecycler),
+                                    lw=lw)
+            
+                if ylim[0] != None and y_lim_exception != i:
+                    axes[i].set_ylim(ylim)
+                
+                if y_log_scale:
+                    plt.yscale('log')
+        else:
+            for i in range(shape[0]):
+                for j in range(shape[1]):
+                    if (i, j) == free_axis:
+                        axes[i, j].axis('off')
+                    else:
+                        linecycler = cycle(self.__lines_styles)
+                        colorcycler = cycle(self.__colors)
+                        if plot_idx < len(y):
+                            for k, array in enumerate(y[plot_idx]):
+                                if torch.is_tensor(array):
+                                    data = array.cpu().detach().numpy()
+                                else:
+                                    data = array
+                                space = int(len(x) / number_xlabels)
+                                axes[i, j].xaxis.set_major_locator(ticker.MultipleLocator(space))
+                                if len(x) > len(data):
+                                    c = len(data)
+                                else:
+                                    c = len(x)
+                                axes[i, j].plot(x[:c], 
+                                                data, 
+                                                label=labels[k], 
+                                                c=next(colorcycler), 
+                                                lw=lw, 
+                                                linestyle=next(linecycler))
+                            axes[i, j].set_title(titles[plot_idx])
+                            if ylim[0] != None:
+                                axes[i, j].set_ylim(ylim)
+                        plot_idx += 1
+                    if y_log_scale:
+                        plt.yscale('log')
+        if 1 in shape:
+            lines, labels = axes[0].get_legend_handles_labels()
+        else:
+            lines, labels = axes[0, 0].get_legend_handles_labels()
+        fig.legend(lines, labels, loc='upper center', ncol=number_of_legend_columns, bbox_to_anchor=legend_loc)
+
+        for ax in axes.flat:
+            ax.set(xlabel=xlabel, ylabel=ylabel)
+
+        # Hide x labels and tick labels for top plots and y ticks for right plots.
+        if plot_all_labels:
+            for ax in axes.flat:
+                ax.label_outer()
+
+        # Adjust layout to prevent overlap
+        plt.tight_layout(rect=[0, 0, 1, 1-add_y_space])
+        plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
+
+    def scatter(self, 
+                x, 
+                y:list, 
+                labels:list, 
+                figure_shape:tuple, 
+                file_name:str, 
+                title:str, 
+                std=[], 
+                true_values=[],
+                true_label='true',
+                plot_legend=True, 
+                legend_loc='best',
+                xlabel='',
+                ylabel='', 
+                xlabel_rotation=None):
+        assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
+        fig = self.__generate_figure(shape=figure_shape)
+
+        ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
+        ax.set_facecolor('xkcd:white')
+        ax.set_title(title)
+
+        if torch.is_tensor(x):
+            x = x.cpu().detach().numpy()
+
+        markercycler = cycle(self.__marker_styles)
+        for i, array in enumerate(y):
+            
+            if torch.is_tensor(array):
+                data = array.cpu().detach().numpy()
+            else:
+                data = array
+
+            if i >= len(std):
+                ax.scatter(x, data, label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
+            if i < len(std):
+                ax.errorbar(x, data, std[i], label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
+        
+        linecycler = cycle(self.__lines_styles)
+        for i, true_value in enumerate(true_values):
+            ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}',c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
+            
+        if plot_legend:
+            plt.legend(loc=legend_loc)
+
+        if xlabel != '':
+            plt.xlabel(xlabel, )
+
+        if ylabel != '':
+            plt.ylabel(ylabel)
+
+        if xlabel_rotation != None:
+            plt.xticks(rotation=45, ha='right')
+
+        plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
+
+        plt.close()
 
     def animate(self, name: str):
         """Builds animation from images saved in self.frames. Then saves animation as gif.