|
@@ -3,7 +3,7 @@ import torch
|
|
|
import imageio
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
-import matplotlib.ticker as ticker
|
|
|
+import matplotlib.ticker as ticker
|
|
|
|
|
|
from matplotlib import rcParams
|
|
|
from itertools import cycle
|
|
@@ -15,6 +15,7 @@ SUSCEPTIBLE_COLOR = '#6399f7'
|
|
|
INFECTIOUS_COLOR = '#f56262'
|
|
|
REMOVED_COLOR = '#83eb5e'
|
|
|
|
|
|
+
|
|
|
class Plotter:
|
|
|
def __init__(self, additional_colors=[], font_size=20, font='serif', font_color='#000000') -> None:
|
|
|
"""Plotter of scientific plots and animations, for dinn.py.
|
|
@@ -65,27 +66,27 @@ class Plotter:
|
|
|
"""Delete all frames that were saved.
|
|
|
"""
|
|
|
self.__frames = []
|
|
|
-
|
|
|
- #TODO comments
|
|
|
- def plot(self,
|
|
|
- x,
|
|
|
- y:list,
|
|
|
- labels:list,
|
|
|
- file_name:str,
|
|
|
- title:str,
|
|
|
- figure_shape:tuple,
|
|
|
+
|
|
|
+ # TODO comments
|
|
|
+ def plot(self,
|
|
|
+ x,
|
|
|
+ y: list,
|
|
|
+ labels: list,
|
|
|
+ file_name: str,
|
|
|
+ title: str,
|
|
|
+ figure_shape: tuple,
|
|
|
event_lookup={},
|
|
|
- is_frame=False,
|
|
|
- is_background=[],
|
|
|
+ is_frame=False,
|
|
|
+ is_background=[],
|
|
|
fill_between=[],
|
|
|
- plot_legend=True,
|
|
|
- y_log_scale=False,
|
|
|
- lw=3,
|
|
|
- legend_loc='best',
|
|
|
+ plot_legend=True,
|
|
|
+ y_log_scale=False,
|
|
|
+ lw=3,
|
|
|
+ legend_loc='best',
|
|
|
ylim=(None, None),
|
|
|
- number_xlabels = 5,
|
|
|
+ number_xlabels=5,
|
|
|
xlabel='',
|
|
|
- ylabel='',
|
|
|
+ ylabel='',
|
|
|
xlabel_rotation=None):
|
|
|
"""Plotting method.
|
|
|
|
|
@@ -112,14 +113,14 @@ class Plotter:
|
|
|
|
|
|
ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
|
|
|
ax.set_facecolor('xkcd:white')
|
|
|
- #ax.yaxis.set_tick_params(length=0, which='both')
|
|
|
- #ax.xaxis.set_tick_params(length=0, which='both')
|
|
|
+ # ax.yaxis.set_tick_params(length=0, which='both')
|
|
|
+ # ax.xaxis.set_tick_params(length=0, which='both')
|
|
|
|
|
|
- #ax.grid(which='major', c='black', lw=0.2, ls='-')
|
|
|
+ # ax.grid(which='major', c='black', lw=0.2, ls='-')
|
|
|
ax.set_title(title)
|
|
|
|
|
|
- #for spine in ('top', 'right', 'bottom', 'left'):
|
|
|
- # ax.spines[spine].set_visible(False)
|
|
|
+ # for spine in ('top', 'right', 'bottom', 'left'):
|
|
|
+ # ax.spines[spine].set_visible(False)
|
|
|
|
|
|
if torch.is_tensor(x):
|
|
|
x = x.cpu().detach().numpy()
|
|
@@ -131,24 +132,24 @@ class Plotter:
|
|
|
if len(is_background) != 0:
|
|
|
if is_background[i]:
|
|
|
alpha = 0.25
|
|
|
-
|
|
|
+
|
|
|
if torch.is_tensor(array):
|
|
|
data = array.cpu().detach().numpy()
|
|
|
else:
|
|
|
data = array
|
|
|
|
|
|
space = int(len(x) / number_xlabels)
|
|
|
- ax.xaxis.set_major_locator(ticker.MultipleLocator(space))
|
|
|
-
|
|
|
+ ax.xaxis.set_major_locator(ticker.MultipleLocator(space))
|
|
|
+
|
|
|
ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle=next(linecycler), c=self.__colors[i % len(self.__colors)])
|
|
|
if i < len(fill_between):
|
|
|
- ax.fill_between(x, data+fill_between[i], data-fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
|
|
|
+ ax.fill_between(x, data + fill_between[i], data - fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
|
|
|
j = i
|
|
|
|
|
|
for event in event_lookup.keys():
|
|
|
j += 1
|
|
|
- plt.axvline(x = event_lookup[event], color = self.__colors[j % len(self.__colors)], label = event, ls=next(linecycler))
|
|
|
-
|
|
|
+ plt.axvline(x=event_lookup[event], color=self.__colors[j % len(self.__colors)], label=event, ls=next(linecycler))
|
|
|
+
|
|
|
if plot_legend:
|
|
|
plt.legend(loc=legend_loc)
|
|
|
|
|
@@ -166,7 +167,7 @@ class Plotter:
|
|
|
|
|
|
if xlabel_rotation != None:
|
|
|
plt.xticks(rotation=xlabel_rotation)
|
|
|
-
|
|
|
+
|
|
|
if not os.path.exists(FRAME_DIR):
|
|
|
os.makedirs(FRAME_DIR)
|
|
|
|
|
@@ -181,29 +182,29 @@ class Plotter:
|
|
|
|
|
|
plt.close()
|
|
|
|
|
|
- def cluster_plot(self,
|
|
|
- x,
|
|
|
- y:list,
|
|
|
- labels,
|
|
|
- shape,
|
|
|
- plots_shape,
|
|
|
- file_name:str,
|
|
|
- titles:list,
|
|
|
+ def cluster_plot(self,
|
|
|
+ x,
|
|
|
+ y: list,
|
|
|
+ labels,
|
|
|
+ shape,
|
|
|
+ plots_shape,
|
|
|
+ file_name: str,
|
|
|
+ titles: list,
|
|
|
number_xlabels=5,
|
|
|
- lw=3,
|
|
|
+ lw=3,
|
|
|
fill_between=[],
|
|
|
event_lookup={},
|
|
|
xlabel='',
|
|
|
ylabel='',
|
|
|
ylim=(None, None),
|
|
|
y_lim_exception=None,
|
|
|
- y_log_scale=False,
|
|
|
- legend_loc=(0.5,0.992),
|
|
|
+ y_log_scale=False,
|
|
|
+ legend_loc=(0.5, 0.992),
|
|
|
add_y_space=0.05,
|
|
|
number_of_legend_columns=1,
|
|
|
same_axes=True,
|
|
|
free_axis=(None, None),
|
|
|
- plot_all_labels = True):
|
|
|
+ plot_all_labels=True):
|
|
|
real_shape = (shape[1] * plots_shape[0], shape[0] * plots_shape[1])
|
|
|
fig, axes = plt.subplots(*shape, figsize=real_shape, sharex=same_axes, sharey=same_axes)
|
|
|
plot_idx = 0
|
|
@@ -224,25 +225,25 @@ class Plotter:
|
|
|
|
|
|
space = int(len(x) / number_xlabels)
|
|
|
axes[i].xaxis.set_major_locator(ticker.MultipleLocator(space))
|
|
|
- axes[i].plot(x,
|
|
|
- data,
|
|
|
+ axes[i].plot(x,
|
|
|
+ data,
|
|
|
linestyle=next(linecycler),
|
|
|
label=labels[j],
|
|
|
c=color,
|
|
|
lw=lw)
|
|
|
axes[i].set_title(titles[i])
|
|
|
if j < len(fill_between[i]):
|
|
|
- axes[i].fill_between(x, data+fill_between[i][j], data-fill_between[i][j], facecolor=color, alpha=0.5)
|
|
|
+ axes[i].fill_between(x, data + fill_between[i][j], data - fill_between[i][j], facecolor=color, alpha=0.5)
|
|
|
for event in event_lookup.keys():
|
|
|
- axes[i].axvline(x=event_lookup[event],
|
|
|
- color=next(colorcycler),
|
|
|
- label=event,
|
|
|
+ axes[i].axvline(x=event_lookup[event],
|
|
|
+ color=next(colorcycler),
|
|
|
+ label=event,
|
|
|
ls=next(linecycler),
|
|
|
lw=lw)
|
|
|
-
|
|
|
+
|
|
|
if ylim[0] != None and y_lim_exception != i:
|
|
|
axes[i].set_ylim(ylim)
|
|
|
-
|
|
|
+
|
|
|
if y_log_scale:
|
|
|
plt.yscale('log')
|
|
|
else:
|
|
@@ -265,11 +266,11 @@ class Plotter:
|
|
|
c = len(data)
|
|
|
else:
|
|
|
c = len(x)
|
|
|
- axes[i, j].plot(x[:c],
|
|
|
- data,
|
|
|
- label=labels[k],
|
|
|
- c=next(colorcycler),
|
|
|
- lw=lw,
|
|
|
+ axes[i, j].plot(x[:c],
|
|
|
+ data,
|
|
|
+ label=labels[k],
|
|
|
+ c=next(colorcycler),
|
|
|
+ lw=lw,
|
|
|
linestyle=next(linecycler))
|
|
|
axes[i, j].set_title(titles[plot_idx])
|
|
|
if ylim[0] != None:
|
|
@@ -292,23 +293,23 @@ class Plotter:
|
|
|
ax.label_outer()
|
|
|
|
|
|
# Adjust layout to prevent overlap
|
|
|
- plt.tight_layout(rect=[0, 0, 1, 1-add_y_space])
|
|
|
+ plt.tight_layout(rect=[0, 0, 1, 1 - add_y_space])
|
|
|
plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
|
|
|
|
|
|
- def scatter(self,
|
|
|
- x,
|
|
|
- y:list,
|
|
|
- labels:list,
|
|
|
- figure_shape:tuple,
|
|
|
- file_name:str,
|
|
|
- title:str,
|
|
|
- std=[],
|
|
|
+ def scatter(self,
|
|
|
+ x,
|
|
|
+ y: list,
|
|
|
+ labels: list,
|
|
|
+ figure_shape: tuple,
|
|
|
+ file_name: str,
|
|
|
+ title: str,
|
|
|
+ std=[],
|
|
|
true_values=[],
|
|
|
true_label='true',
|
|
|
- plot_legend=True,
|
|
|
+ plot_legend=True,
|
|
|
legend_loc='best',
|
|
|
xlabel='',
|
|
|
- ylabel='',
|
|
|
+ ylabel='',
|
|
|
xlabel_rotation=None):
|
|
|
assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
|
|
|
fig = self.__generate_figure(shape=figure_shape)
|
|
@@ -322,7 +323,7 @@ class Plotter:
|
|
|
|
|
|
markercycler = cycle(self.__marker_styles)
|
|
|
for i, array in enumerate(y):
|
|
|
-
|
|
|
+
|
|
|
if torch.is_tensor(array):
|
|
|
data = array.cpu().detach().numpy()
|
|
|
else:
|
|
@@ -332,11 +333,11 @@ class Plotter:
|
|
|
ax.scatter(x, data, label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
|
|
|
if i < len(std):
|
|
|
ax.errorbar(x, data, std[i], label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
|
|
|
-
|
|
|
+
|
|
|
linecycler = cycle(self.__lines_styles)
|
|
|
for i, true_value in enumerate(true_values):
|
|
|
- ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}',c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
|
|
|
-
|
|
|
+ ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}', c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
|
|
|
+
|
|
|
if plot_legend:
|
|
|
plt.legend(loc=legend_loc)
|
|
|
|