|
@@ -0,0 +1,152 @@
|
|
|
|
|
+import os
|
|
|
|
|
+import torch
|
|
|
|
|
+import imageio
|
|
|
|
|
+import matplotlib.pyplot as plt
|
|
|
|
|
+
|
|
|
|
|
+from matplotlib import rcParams
|
|
|
|
|
+
|
|
|
|
|
+FRAME_DIR = 'visualizations/temp/'
|
|
|
|
|
+VISUALISATION_DIR = 'visualizations/'
|
|
|
|
|
+
|
|
|
|
|
+SUSCEPTIBLE_COLOR = '#6399f7'
|
|
|
|
|
+INFECTIOUS_COLOR = '#f56262'
|
|
|
|
|
+REMOVED_COLOR = '#83eb5e'
|
|
|
|
|
+
|
|
|
|
|
+class Plotter:
|
|
|
|
|
+ def __init__(self, additional_colors=[], font_size=12, font='Comfortaa', font_color='#595959') -> None:
|
|
|
|
|
+ """Plotter of scientific plots and animations, for dinn.py.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ additional_colors (list, optional): List of strings that describe additional colors for plotting (used in order). Defaults to [].
|
|
|
|
|
+ font_size (int, optional): Size of the fonts in the plots. Defaults to 12.
|
|
|
|
|
+ font (str, optional): Font family used in the plots. Defaults to 'Comfortaa'.
|
|
|
|
|
+ font_color (str, optional): Color of the fonts used in the plots. Defaults to '#595959'.
|
|
|
|
|
+ """
|
|
|
|
|
+ self.__colors = [SUSCEPTIBLE_COLOR, INFECTIOUS_COLOR, REMOVED_COLOR] + additional_colors
|
|
|
|
|
+
|
|
|
|
|
+ rcParams['font.family'] = font
|
|
|
|
|
+ rcParams['font.size'] = font_size
|
|
|
|
|
+
|
|
|
|
|
+ rcParams['text.color'] = font_color
|
|
|
|
|
+ rcParams['axes.labelcolor'] = font_color
|
|
|
|
|
+ rcParams['xtick.color'] = font_color
|
|
|
|
|
+ rcParams['ytick.color'] = font_color
|
|
|
|
|
+
|
|
|
|
|
+ self.__frames = []
|
|
|
|
|
+
|
|
|
|
|
+ def __generate_figure(self, shape=(4, 4)):
|
|
|
|
|
+ """Generate figure for a plot
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ shape (tuple, optional): Size of the plot in dimensions. Defaults to (4, 4).
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ Figure: plt.Figure that was generated.
|
|
|
|
|
+ """
|
|
|
|
|
+ fig = plt.figure(figsize=shape)
|
|
|
|
|
+ return fig
|
|
|
|
|
+
|
|
|
|
|
+ def reset_animation(self):
|
|
|
|
|
+ """Delete all frames that were saved.
|
|
|
|
|
+ """
|
|
|
|
|
+ self.__frames = []
|
|
|
|
|
+
|
|
|
|
|
+ def plot(self,
|
|
|
|
|
+ x,
|
|
|
|
|
+ y:list,
|
|
|
|
|
+ labels:list,
|
|
|
|
|
+ file_name:str,
|
|
|
|
|
+ title:str,
|
|
|
|
|
+ figure_shape:tuple,
|
|
|
|
|
+ is_frame=False,
|
|
|
|
|
+ is_background=[],
|
|
|
|
|
+ plot_legend=True,
|
|
|
|
|
+ y_log_scale=False,
|
|
|
|
|
+ lw=3,
|
|
|
|
|
+ legend_loc='best',
|
|
|
|
|
+ ylim=(None, None),
|
|
|
|
|
+ xlabel='',
|
|
|
|
|
+ ylabel=''):
|
|
|
|
|
+ """Plotting method.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ x (ndarray): Array of the definition values for the plots (x-values).
|
|
|
|
|
+ y (list): List of torch.Tensors or numpy.arrays that are supposed to be plotted in the graph.
|
|
|
|
|
+ labels (list): List of the string labels for each plot in the graph. Needs to have the same length as y.
|
|
|
|
|
+ file_name (str): Name for the file that will be saved containing the plots.
|
|
|
|
|
+ title (str): Title of the plot.
|
|
|
|
|
+ figure_shape (tuple): Tuple of integers. Shape of the figure.
|
|
|
|
|
+ is_frame (bool, optional): Decides if plot will be used as frame of a animation. Defaults to False.
|
|
|
|
|
+ is_background (list, optional): List of strings that show if the corresponding element is supposed to be highlighted('f') or not('b'). Either needs to be empty(no highlighting) or the same length as y. Defaults to [].
|
|
|
|
|
+ plot_legend (bool, optional): Decides whether a legend will be plotted. Defaults to True.
|
|
|
|
|
+ y_log_scale (bool, optional): Decides whether the y axis is scaled in log. Defaults to False.
|
|
|
|
|
+ lw (int, optional): Size of the lines plotted. Defaults to 3.
|
|
|
|
|
+ legend_loc (str, optional): Position for the legend. Will be passed to plt.legend(loc=legend_loc). Is not used, when plot_legend=False. Defaults to 'best'.
|
|
|
|
|
+ ylim (tuple, optional): Tuple that holds the limits for the y axis. Will be passed to plt.ylim(ylim). Default does not set any limits.. Defaults to (None, None).
|
|
|
|
|
+ xlabel (str, optional): Label for the x axis. Defaults to ''.
|
|
|
|
|
+ ylabel (str, optional): Label for the y axis. Defaults to ''.
|
|
|
|
|
+ """
|
|
|
|
|
+ assert len(y) == len(labels), "There must be the same amount of labels as there are plots."
|
|
|
|
|
+ assert len(is_background) == 0 or len(y) == len(is_background), "If given the back_foreground list must have the same length as labels has."
|
|
|
|
|
+ fig = self.__generate_figure(shape=figure_shape)
|
|
|
|
|
+
|
|
|
|
|
+ ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
|
|
|
|
|
+ ax.set_facecolor('xkcd:white')
|
|
|
|
|
+ ax.yaxis.set_tick_params(length=0, which='both')
|
|
|
|
|
+ ax.xaxis.set_tick_params(length=0, which='both')
|
|
|
|
|
+ ax.grid(which='major', c='black', lw=0.2, ls='-')
|
|
|
|
|
+ ax.set_title(title)
|
|
|
|
|
+
|
|
|
|
|
+ for spine in ('top', 'right', 'bottom', 'left'):
|
|
|
|
|
+ ax.spines[spine].set_visible(False)
|
|
|
|
|
+
|
|
|
|
|
+ for i, array in enumerate(y):
|
|
|
|
|
+ alpha = 1
|
|
|
|
|
+ if len(is_background) != 0:
|
|
|
|
|
+ if is_background[i]:
|
|
|
|
|
+ alpha = 0.25
|
|
|
|
|
+
|
|
|
|
|
+ if torch.is_tensor(array):
|
|
|
|
|
+ data = array.cpu().detach().numpy()
|
|
|
|
|
+ else:
|
|
|
|
|
+ data = array
|
|
|
|
|
+
|
|
|
|
|
+ if len(is_background) != 0 and alpha == 1:
|
|
|
|
|
+ ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle='dashed', c=self.__colors[i % len(self.__colors)])
|
|
|
|
|
+ else:
|
|
|
|
|
+ ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, c=self.__colors[i % len(self.__colors)])
|
|
|
|
|
+
|
|
|
|
|
+ if plot_legend:
|
|
|
|
|
+ plt.legend(loc=legend_loc)
|
|
|
|
|
+
|
|
|
|
|
+ if ylim[0] != None:
|
|
|
|
|
+ plt.ylim(*ylim)
|
|
|
|
|
+
|
|
|
|
|
+ if y_log_scale:
|
|
|
|
|
+ plt.yscale('log')
|
|
|
|
|
+
|
|
|
|
|
+ if xlabel != '':
|
|
|
|
|
+ plt.xlabel(xlabel)
|
|
|
|
|
+
|
|
|
|
|
+ if ylabel != '':
|
|
|
|
|
+ plt.ylabel(ylabel)
|
|
|
|
|
+
|
|
|
|
|
+ if not os.path.exists(FRAME_DIR):
|
|
|
|
|
+ os.makedirs(FRAME_DIR)
|
|
|
|
|
+
|
|
|
|
|
+ if is_frame:
|
|
|
|
|
+ frame_path = FRAME_DIR + 'frame.png'
|
|
|
|
|
+ plt.savefig(frame_path)
|
|
|
|
|
+ plt.close(fig)
|
|
|
|
|
+ self.__frames.append(imageio.imread(frame_path))
|
|
|
|
|
+ os.remove(frame_path)
|
|
|
|
|
+ else:
|
|
|
|
|
+ plt.savefig(VISUALISATION_DIR + f'{file_name}.png')
|
|
|
|
|
+
|
|
|
|
|
+ def animate(self, name: str):
|
|
|
|
|
+ """Builds animation from images saved in self.frames. Then saves animation as gif.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ name (str): Name of the gif file.
|
|
|
|
|
+ """
|
|
|
|
|
+ imageio.mimsave(VISUALISATION_DIR + f'{name}.gif', self.__frames, duration=0.5)
|