|
|
@@ -1,9 +1,12 @@
|
|
|
import os
|
|
|
import torch
|
|
|
import imageio
|
|
|
+import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
+import matplotlib.ticker as ticker
|
|
|
|
|
|
from matplotlib import rcParams
|
|
|
+from itertools import cycle
|
|
|
|
|
|
FRAME_DIR = 'visualizations/temp/'
|
|
|
VISUALISATION_DIR = 'visualizations/'
|
|
|
@@ -13,7 +16,7 @@ INFECTIOUS_COLOR = '#f56262'
|
|
|
REMOVED_COLOR = '#83eb5e'
|
|
|
|
|
|
class Plotter:
|
|
|
- def __init__(self, additional_colors=[], font_size=12, font='Comfortaa', font_color='#595959') -> None:
|
|
|
+ def __init__(self, additional_colors=[], font_size=20, font='serif', font_color='#000000') -> None:
|
|
|
"""Plotter of scientific plots and animations, for dinn.py.
|
|
|
|
|
|
Args:
|
|
|
@@ -24,14 +27,26 @@ class Plotter:
|
|
|
"""
|
|
|
self.__colors = [SUSCEPTIBLE_COLOR, INFECTIOUS_COLOR, REMOVED_COLOR] + additional_colors
|
|
|
|
|
|
+ self.__lines_styles = ['solid', 'dotted', 'dashdot', (5, (10, 3)), (0, (5, 1)), (0, (3, 5, 1, 5)), (0, (3, 1, 1, 1)), (0, (3, 5, 1, 5, 1, 5)), (0, (3, 1, 1, 1, 1, 1))]
|
|
|
+
|
|
|
+ self.__marker_styles = ['o', '^', 's']
|
|
|
+
|
|
|
+ rcParams['text.usetex'] = True
|
|
|
+ rcParams['text.color'] = font_color
|
|
|
+
|
|
|
rcParams['font.family'] = font
|
|
|
rcParams['font.size'] = font_size
|
|
|
|
|
|
- rcParams['text.color'] = font_color
|
|
|
+ rcParams['pgf.texsystem'] = 'pdflatex'
|
|
|
+ rcParams['pgf.rcfonts'] = False
|
|
|
+
|
|
|
rcParams['axes.labelcolor'] = font_color
|
|
|
rcParams['xtick.color'] = font_color
|
|
|
rcParams['ytick.color'] = font_color
|
|
|
|
|
|
+ plt.rc("text", usetex=True)
|
|
|
+ plt.rc("text.latex", preamble=r"\usepackage{amssymb} \usepackage{wasysym}")
|
|
|
+
|
|
|
self.__frames = []
|
|
|
|
|
|
def __generate_figure(self, shape=(4, 4)):
|
|
|
@@ -43,7 +58,7 @@ class Plotter:
|
|
|
Returns:
|
|
|
Figure: plt.Figure that was generated.
|
|
|
"""
|
|
|
- fig = plt.figure(figsize=shape)
|
|
|
+ fig = plt.figure(figsize=shape, constrained_layout=True)
|
|
|
return fig
|
|
|
|
|
|
def reset_animation(self):
|
|
|
@@ -51,6 +66,7 @@ class Plotter:
|
|
|
"""
|
|
|
self.__frames = []
|
|
|
|
|
|
+ #TODO comments
|
|
|
def plot(self,
|
|
|
x,
|
|
|
y:list,
|
|
|
@@ -58,16 +74,19 @@ class Plotter:
|
|
|
file_name:str,
|
|
|
title:str,
|
|
|
figure_shape:tuple,
|
|
|
+ event_lookup={},
|
|
|
is_frame=False,
|
|
|
is_background=[],
|
|
|
+ fill_between=[],
|
|
|
plot_legend=True,
|
|
|
y_log_scale=False,
|
|
|
lw=3,
|
|
|
legend_loc='best',
|
|
|
ylim=(None, None),
|
|
|
+ number_xlabels = 5,
|
|
|
xlabel='',
|
|
|
ylabel='',
|
|
|
- xlabel_rotation=0):
|
|
|
+ xlabel_rotation=None):
|
|
|
"""Plotting method.
|
|
|
|
|
|
Args:
|
|
|
@@ -87,20 +106,26 @@ class Plotter:
|
|
|
xlabel (str, optional): Label for the x axis. Defaults to ''.
|
|
|
ylabel (str, optional): Label for the y axis. Defaults to ''.
|
|
|
"""
|
|
|
- assert len(y) == len(labels), "There must be the same amount of labels as there are plots."
|
|
|
+ assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
|
|
|
assert len(is_background) == 0 or len(y) == len(is_background), "If given the back_foreground list must have the same length as labels has."
|
|
|
fig = self.__generate_figure(shape=figure_shape)
|
|
|
|
|
|
ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
|
|
|
ax.set_facecolor('xkcd:white')
|
|
|
- ax.yaxis.set_tick_params(length=0, which='both')
|
|
|
- ax.xaxis.set_tick_params(length=0, which='both')
|
|
|
- ax.grid(which='major', c='black', lw=0.2, ls='-')
|
|
|
+ #ax.yaxis.set_tick_params(length=0, which='both')
|
|
|
+ #ax.xaxis.set_tick_params(length=0, which='both')
|
|
|
+
|
|
|
+ #ax.grid(which='major', c='black', lw=0.2, ls='-')
|
|
|
ax.set_title(title)
|
|
|
|
|
|
- for spine in ('top', 'right', 'bottom', 'left'):
|
|
|
- ax.spines[spine].set_visible(False)
|
|
|
+ #for spine in ('top', 'right', 'bottom', 'left'):
|
|
|
+ # ax.spines[spine].set_visible(False)
|
|
|
+
|
|
|
+ if torch.is_tensor(x):
|
|
|
+ x = x.cpu().detach().numpy()
|
|
|
|
|
|
+ linecycler = cycle(self.__lines_styles)
|
|
|
+ j = 0
|
|
|
for i, array in enumerate(y):
|
|
|
alpha = 1
|
|
|
if len(is_background) != 0:
|
|
|
@@ -111,12 +136,19 @@ class Plotter:
|
|
|
data = array.cpu().detach().numpy()
|
|
|
else:
|
|
|
data = array
|
|
|
-
|
|
|
- if len(is_background) != 0 and alpha == 1:
|
|
|
- ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle='dashed', c=self.__colors[i % len(self.__colors)])
|
|
|
- else:
|
|
|
- ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, c=self.__colors[i % len(self.__colors)])
|
|
|
|
|
|
+ space = int(len(x) / number_xlabels)
|
|
|
+ ax.xaxis.set_major_locator(ticker.MultipleLocator(space))
|
|
|
+
|
|
|
+ ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle=next(linecycler), c=self.__colors[i % len(self.__colors)])
|
|
|
+ if i < len(fill_between):
|
|
|
+ ax.fill_between(x, data+fill_between[i], data-fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
|
|
|
+ j = i
|
|
|
+
|
|
|
+ for event in event_lookup.keys():
|
|
|
+ j += 1
|
|
|
+ plt.axvline(x = event_lookup[event], color = self.__colors[j % len(self.__colors)], label = event, ls=next(linecycler))
|
|
|
+
|
|
|
if plot_legend:
|
|
|
plt.legend(loc=legend_loc)
|
|
|
|
|
|
@@ -132,7 +164,8 @@ class Plotter:
|
|
|
if ylabel != '':
|
|
|
plt.ylabel(ylabel)
|
|
|
|
|
|
- plt.xticks(rotation=xlabel_rotation)
|
|
|
+ if xlabel_rotation != None:
|
|
|
+ plt.xticks(rotation=xlabel_rotation)
|
|
|
|
|
|
if not os.path.exists(FRAME_DIR):
|
|
|
os.makedirs(FRAME_DIR)
|
|
|
@@ -144,7 +177,181 @@ class Plotter:
|
|
|
self.__frames.append(imageio.imread(frame_path))
|
|
|
os.remove(frame_path)
|
|
|
else:
|
|
|
- plt.savefig(VISUALISATION_DIR + f'{file_name}.png')
|
|
|
+ plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
|
|
|
+
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+ def cluster_plot(self,
|
|
|
+ x,
|
|
|
+ y:list,
|
|
|
+ labels,
|
|
|
+ shape,
|
|
|
+ plots_shape,
|
|
|
+ file_name:str,
|
|
|
+ titles:list,
|
|
|
+ number_xlabels=5,
|
|
|
+ lw=3,
|
|
|
+ fill_between=[],
|
|
|
+ event_lookup={},
|
|
|
+ xlabel='',
|
|
|
+ ylabel='',
|
|
|
+ ylim=(None, None),
|
|
|
+ y_lim_exception=None,
|
|
|
+ y_log_scale=False,
|
|
|
+ legend_loc=(0.5,0.992),
|
|
|
+ add_y_space=0.05,
|
|
|
+ number_of_legend_columns=1,
|
|
|
+ same_axes=True,
|
|
|
+ free_axis=(None, None),
|
|
|
+ plot_all_labels = True):
|
|
|
+ real_shape = (shape[1] * plots_shape[0], shape[0] * plots_shape[1])
|
|
|
+ fig, axes = plt.subplots(*shape, figsize=real_shape, sharex=same_axes, sharey=same_axes)
|
|
|
+ plot_idx = 0
|
|
|
+
|
|
|
+ if torch.is_tensor(x):
|
|
|
+ x = x.cpu().detach().numpy()
|
|
|
+
|
|
|
+ if 1 in shape:
|
|
|
+ for i in range(len(axes)):
|
|
|
+ linecycler = cycle(self.__lines_styles)
|
|
|
+ colorcycler = cycle(self.__colors)
|
|
|
+ for j, array in enumerate(y[i]):
|
|
|
+ color = next(colorcycler)
|
|
|
+ if torch.is_tensor(array):
|
|
|
+ data = array.cpu().detach().numpy()
|
|
|
+ else:
|
|
|
+ data = array
|
|
|
+
|
|
|
+ space = int(len(x) / number_xlabels)
|
|
|
+ axes[i].xaxis.set_major_locator(ticker.MultipleLocator(space))
|
|
|
+ axes[i].plot(x,
|
|
|
+ data,
|
|
|
+ linestyle=next(linecycler),
|
|
|
+ label=labels[j],
|
|
|
+ c=color,
|
|
|
+ lw=lw)
|
|
|
+ axes[i].set_title(titles[i])
|
|
|
+ if j < len(fill_between[i]):
|
|
|
+ axes[i].fill_between(x, data+fill_between[i][j], data-fill_between[i][j], facecolor=color, alpha=0.5)
|
|
|
+ for event in event_lookup.keys():
|
|
|
+ axes[i].axvline(x=event_lookup[event],
|
|
|
+ color=next(colorcycler),
|
|
|
+ label=event,
|
|
|
+ ls=next(linecycler),
|
|
|
+ lw=lw)
|
|
|
+
|
|
|
+ if ylim[0] != None and y_lim_exception != i:
|
|
|
+ axes[i].set_ylim(ylim)
|
|
|
+
|
|
|
+ if y_log_scale:
|
|
|
+ plt.yscale('log')
|
|
|
+ else:
|
|
|
+ for i in range(shape[0]):
|
|
|
+ for j in range(shape[1]):
|
|
|
+ if (i, j) == free_axis:
|
|
|
+ axes[i, j].axis('off')
|
|
|
+ else:
|
|
|
+ linecycler = cycle(self.__lines_styles)
|
|
|
+ colorcycler = cycle(self.__colors)
|
|
|
+ if plot_idx < len(y):
|
|
|
+ for k, array in enumerate(y[plot_idx]):
|
|
|
+ if torch.is_tensor(array):
|
|
|
+ data = array.cpu().detach().numpy()
|
|
|
+ else:
|
|
|
+ data = array
|
|
|
+ space = int(len(x) / number_xlabels)
|
|
|
+ axes[i, j].xaxis.set_major_locator(ticker.MultipleLocator(space))
|
|
|
+ if len(x) > len(data):
|
|
|
+ c = len(data)
|
|
|
+ else:
|
|
|
+ c = len(x)
|
|
|
+ axes[i, j].plot(x[:c],
|
|
|
+ data,
|
|
|
+ label=labels[k],
|
|
|
+ c=next(colorcycler),
|
|
|
+ lw=lw,
|
|
|
+ linestyle=next(linecycler))
|
|
|
+ axes[i, j].set_title(titles[plot_idx])
|
|
|
+ if ylim[0] != None:
|
|
|
+ axes[i, j].set_ylim(ylim)
|
|
|
+ plot_idx += 1
|
|
|
+ if y_log_scale:
|
|
|
+ plt.yscale('log')
|
|
|
+ if 1 in shape:
|
|
|
+ lines, labels = axes[0].get_legend_handles_labels()
|
|
|
+ else:
|
|
|
+ lines, labels = axes[0, 0].get_legend_handles_labels()
|
|
|
+ fig.legend(lines, labels, loc='upper center', ncol=number_of_legend_columns, bbox_to_anchor=legend_loc)
|
|
|
+
|
|
|
+ for ax in axes.flat:
|
|
|
+ ax.set(xlabel=xlabel, ylabel=ylabel)
|
|
|
+
|
|
|
+ # Hide x labels and tick labels for top plots and y ticks for right plots.
|
|
|
+ if plot_all_labels:
|
|
|
+ for ax in axes.flat:
|
|
|
+ ax.label_outer()
|
|
|
+
|
|
|
+ # Adjust layout to prevent overlap
|
|
|
+ plt.tight_layout(rect=[0, 0, 1, 1-add_y_space])
|
|
|
+ plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
|
|
|
+
|
|
|
+ def scatter(self,
|
|
|
+ x,
|
|
|
+ y:list,
|
|
|
+ labels:list,
|
|
|
+ figure_shape:tuple,
|
|
|
+ file_name:str,
|
|
|
+ title:str,
|
|
|
+ std=[],
|
|
|
+ true_values=[],
|
|
|
+ true_label='true',
|
|
|
+ plot_legend=True,
|
|
|
+ legend_loc='best',
|
|
|
+ xlabel='',
|
|
|
+ ylabel='',
|
|
|
+ xlabel_rotation=None):
|
|
|
+ assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
|
|
|
+ fig = self.__generate_figure(shape=figure_shape)
|
|
|
+
|
|
|
+ ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
|
|
|
+ ax.set_facecolor('xkcd:white')
|
|
|
+ ax.set_title(title)
|
|
|
+
|
|
|
+ if torch.is_tensor(x):
|
|
|
+ x = x.cpu().detach().numpy()
|
|
|
+
|
|
|
+ markercycler = cycle(self.__marker_styles)
|
|
|
+ for i, array in enumerate(y):
|
|
|
+
|
|
|
+ if torch.is_tensor(array):
|
|
|
+ data = array.cpu().detach().numpy()
|
|
|
+ else:
|
|
|
+ data = array
|
|
|
+
|
|
|
+ if i >= len(std):
|
|
|
+ ax.scatter(x, data, label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
|
|
|
+ if i < len(std):
|
|
|
+ ax.errorbar(x, data, std[i], label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
|
|
|
+
|
|
|
+ linecycler = cycle(self.__lines_styles)
|
|
|
+ for i, true_value in enumerate(true_values):
|
|
|
+ ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}',c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
|
|
|
+
|
|
|
+ if plot_legend:
|
|
|
+ plt.legend(loc=legend_loc)
|
|
|
+
|
|
|
+ if xlabel != '':
|
|
|
+ plt.xlabel(xlabel, )
|
|
|
+
|
|
|
+ if ylabel != '':
|
|
|
+ plt.ylabel(ylabel)
|
|
|
+
|
|
|
+ if xlabel_rotation != None:
|
|
|
+ plt.xticks(rotation=45, ha='right')
|
|
|
+
|
|
|
+ plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
|
|
|
+
|
|
|
+ plt.close()
|
|
|
|
|
|
def animate(self, name: str):
|
|
|
"""Builds animation from images saved in self.frames. Then saves animation as gif.
|