phillip.rothenbeck hai 4 meses
pai
achega
f2ced484b6
Modificáronse 3 ficheiros con 159 adicións e 156 borrados
  1. 32 30
      src/dataset.py
  2. 71 70
      src/plotter.py
  3. 56 56
      states_training.py

+ 32 - 30
src/dataset.py

@@ -1,20 +1,23 @@
 import torch
+import numpy as np
 from enum import Enum
 
+
 class Norms(Enum):
-    POPULATION=0
-    MIN_MAX=1
-    CONSTANT=2
+    POPULATION = 0
+    MIN_MAX = 1
+    CONSTANT = 2
+
 
 class PandemicDataset:
-    def __init__(self, 
-                 name:str,
-                 group_names:list, 
-                 N: int, 
-                 t, 
-                 *groups, 
+    def __init__(self,
+                 name: str,
+                 group_names: list,
+                 N: int,
+                 t,
+                 *groups,
                  norm_name=Norms.MIN_MAX,
-                 C = 10**5,
+                 C=10**5,
                  use_scaled_time=False):
         """Class to hold all data for one training process.
 
@@ -60,12 +63,12 @@ class PandemicDataset:
 
         self.__group_dict = {}
         for i, name in enumerate(group_names):
-            self.__group_dict.update({name : i})
+            self.__group_dict.update({name: i})
 
         self.__group_names = group_names
 
         self.__groups = [torch.tensor(group, device=self.device_name) for group in groups]
-        
+
         self.__mins = [torch.min(group) for group in self.__groups]
         self.__maxs = [torch.max(group) for group in self.__groups]
         self.__norms = self.__norm(self.__groups)
@@ -73,50 +76,49 @@ class PandemicDataset:
     @property
     def number_groups(self):
         return len(self.__group_names)
-    
+
     @property
     def data(self):
         return self.__groups
-    
+
     @property
-    def group_names(self):
+    def group_names(self) -> np.ndarray:
         return self.__group_names
-    
+
     def __population_norm(self, data):
         return [(data[i] / self.N) for i in range(self.number_groups)]
-    
+
     def __population_denorm(self, data):
         return [(data[i] * self.N) for i in range(self.number_groups)]
 
     def __min_max_norm(self, data):
         return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
-    
+
     def __min_max_denorm(self, data):
         return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * data[i]) for i in range(self.number_groups)]
-    
+
     def __constant_norm(self, data):
         return [(data[i] / self.C) for i in range(self.number_groups)]
 
     def __constant_denorm(self, data):
         return [(data[i] * self.C) for i in range(self.number_groups)]
 
-    def get_normalized_data(self, data:list):
+    def get_normalized_data(self, data: list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
         return self.__norm(data)
-    
-    def get_denormalized_data(self, data:list):
+
+    def get_denormalized_data(self, data: list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
         return self.__denorm(data)
 
-    def get_group(self, name:str):
+    def get_group(self, name: str):
         return self.__groups[self.__group_dict[name]]
-    
-    def get_min(self, name:str):
+
+    def get_min(self, name: str):
         return self.__mins[self.__group_dict[name]]
-    
-    def get_max(self, name:str):
+
+    def get_max(self, name: str):
         return self.__maxs[self.__group_dict[name]]
-    
-    def get_norm(self, name:str):
+
+    def get_norm(self, name: str):
         return self.__norms[self.__group_dict[name]]
-    

+ 71 - 70
src/plotter.py

@@ -3,7 +3,7 @@ import torch
 import imageio
 import numpy as np
 import matplotlib.pyplot as plt
-import matplotlib.ticker as ticker 
+import matplotlib.ticker as ticker
 
 from matplotlib import rcParams
 from itertools import cycle
@@ -15,6 +15,7 @@ SUSCEPTIBLE_COLOR = '#6399f7'
 INFECTIOUS_COLOR = '#f56262'
 REMOVED_COLOR = '#83eb5e'
 
+
 class Plotter:
     def __init__(self, additional_colors=[], font_size=20, font='serif', font_color='#000000') -> None:
         """Plotter of scientific plots and animations, for dinn.py.
@@ -65,27 +66,27 @@ class Plotter:
         """Delete all frames that were saved.
         """
         self.__frames = []
-    
-    #TODO comments
-    def plot(self, 
-             x, 
-             y:list, 
-             labels:list,
-             file_name:str,
-             title:str, 
-             figure_shape:tuple, 
+
+    # TODO comments
+    def plot(self,
+             x,
+             y: list,
+             labels: list,
+             file_name: str,
+             title: str,
+             figure_shape: tuple,
              event_lookup={},
-             is_frame=False, 
-             is_background=[], 
+             is_frame=False,
+             is_background=[],
              fill_between=[],
-             plot_legend=True, 
-             y_log_scale=False, 
-             lw=3, 
-             legend_loc='best', 
+             plot_legend=True,
+             y_log_scale=False,
+             lw=3,
+             legend_loc='best',
              ylim=(None, None),
-             number_xlabels = 5,
+             number_xlabels=5,
              xlabel='',
-             ylabel='', 
+             ylabel='',
              xlabel_rotation=None):
         """Plotting method.
 
@@ -112,14 +113,14 @@ class Plotter:
 
         ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
         ax.set_facecolor('xkcd:white')
-        #ax.yaxis.set_tick_params(length=0, which='both')
-        #ax.xaxis.set_tick_params(length=0, which='both')
+        # ax.yaxis.set_tick_params(length=0, which='both')
+        # ax.xaxis.set_tick_params(length=0, which='both')
 
-        #ax.grid(which='major', c='black', lw=0.2, ls='-')
+        # ax.grid(which='major', c='black', lw=0.2, ls='-')
         ax.set_title(title)
 
-        #for spine in ('top', 'right', 'bottom', 'left'):
-         #   ax.spines[spine].set_visible(False)
+        # for spine in ('top', 'right', 'bottom', 'left'):
+        #   ax.spines[spine].set_visible(False)
 
         if torch.is_tensor(x):
             x = x.cpu().detach().numpy()
@@ -131,24 +132,24 @@ class Plotter:
             if len(is_background) != 0:
                 if is_background[i]:
                     alpha = 0.25
-            
+
             if torch.is_tensor(array):
                 data = array.cpu().detach().numpy()
             else:
                 data = array
 
             space = int(len(x) / number_xlabels)
-            ax.xaxis.set_major_locator(ticker.MultipleLocator(space)) 
-        
+            ax.xaxis.set_major_locator(ticker.MultipleLocator(space))
+
             ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle=next(linecycler), c=self.__colors[i % len(self.__colors)])
             if i < len(fill_between):
-                ax.fill_between(x, data+fill_between[i], data-fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
+                ax.fill_between(x, data + fill_between[i], data - fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
             j = i
 
         for event in event_lookup.keys():
             j += 1
-            plt.axvline(x = event_lookup[event], color = self.__colors[j % len(self.__colors)], label = event, ls=next(linecycler))
-            
+            plt.axvline(x=event_lookup[event], color=self.__colors[j % len(self.__colors)], label=event, ls=next(linecycler))
+
         if plot_legend:
             plt.legend(loc=legend_loc)
 
@@ -166,7 +167,7 @@ class Plotter:
 
         if xlabel_rotation != None:
             plt.xticks(rotation=xlabel_rotation)
- 
+
         if not os.path.exists(FRAME_DIR):
             os.makedirs(FRAME_DIR)
 
@@ -181,29 +182,29 @@ class Plotter:
 
         plt.close()
 
-    def cluster_plot(self, 
-                     x, 
-                     y:list, 
-                     labels, 
-                     shape, 
-                     plots_shape, 
-                     file_name:str, 
-                     titles:list, 
+    def cluster_plot(self,
+                     x,
+                     y: list,
+                     labels,
+                     shape,
+                     plots_shape,
+                     file_name: str,
+                     titles: list,
                      number_xlabels=5,
-                     lw=3, 
+                     lw=3,
                      fill_between=[],
                      event_lookup={},
                      xlabel='',
                      ylabel='',
                      ylim=(None, None),
                      y_lim_exception=None,
-                     y_log_scale=False, 
-                     legend_loc=(0.5,0.992),
+                     y_log_scale=False,
+                     legend_loc=(0.5, 0.992),
                      add_y_space=0.05,
                      number_of_legend_columns=1,
                      same_axes=True,
                      free_axis=(None, None),
-                     plot_all_labels = True):
+                     plot_all_labels=True):
         real_shape = (shape[1] * plots_shape[0], shape[0] * plots_shape[1])
         fig, axes = plt.subplots(*shape, figsize=real_shape, sharex=same_axes, sharey=same_axes)
         plot_idx = 0
@@ -224,25 +225,25 @@ class Plotter:
 
                     space = int(len(x) / number_xlabels)
                     axes[i].xaxis.set_major_locator(ticker.MultipleLocator(space))
-                    axes[i].plot(x, 
-                                 data, 
+                    axes[i].plot(x,
+                                 data,
                                  linestyle=next(linecycler),
                                  label=labels[j],
                                  c=color,
                                  lw=lw)
                     axes[i].set_title(titles[i])
                     if j < len(fill_between[i]):
-                        axes[i].fill_between(x, data+fill_between[i][j], data-fill_between[i][j], facecolor=color, alpha=0.5)
+                        axes[i].fill_between(x, data + fill_between[i][j], data - fill_between[i][j], facecolor=color, alpha=0.5)
                 for event in event_lookup.keys():
-                    axes[i].axvline(x=event_lookup[event], 
-                                    color=next(colorcycler), 
-                                    label=event, 
+                    axes[i].axvline(x=event_lookup[event],
+                                    color=next(colorcycler),
+                                    label=event,
                                     ls=next(linecycler),
                                     lw=lw)
-            
+
                 if ylim[0] != None and y_lim_exception != i:
                     axes[i].set_ylim(ylim)
-                
+
                 if y_log_scale:
                     plt.yscale('log')
         else:
@@ -265,11 +266,11 @@ class Plotter:
                                     c = len(data)
                                 else:
                                     c = len(x)
-                                axes[i, j].plot(x[:c], 
-                                                data, 
-                                                label=labels[k], 
-                                                c=next(colorcycler), 
-                                                lw=lw, 
+                                axes[i, j].plot(x[:c],
+                                                data,
+                                                label=labels[k],
+                                                c=next(colorcycler),
+                                                lw=lw,
                                                 linestyle=next(linecycler))
                             axes[i, j].set_title(titles[plot_idx])
                             if ylim[0] != None:
@@ -292,23 +293,23 @@ class Plotter:
                 ax.label_outer()
 
         # Adjust layout to prevent overlap
-        plt.tight_layout(rect=[0, 0, 1, 1-add_y_space])
+        plt.tight_layout(rect=[0, 0, 1, 1 - add_y_space])
         plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
 
-    def scatter(self, 
-                x, 
-                y:list, 
-                labels:list, 
-                figure_shape:tuple, 
-                file_name:str, 
-                title:str, 
-                std=[], 
+    def scatter(self,
+                x,
+                y: list,
+                labels: list,
+                figure_shape: tuple,
+                file_name: str,
+                title: str,
+                std=[],
                 true_values=[],
                 true_label='true',
-                plot_legend=True, 
+                plot_legend=True,
                 legend_loc='best',
                 xlabel='',
-                ylabel='', 
+                ylabel='',
                 xlabel_rotation=None):
         assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
         fig = self.__generate_figure(shape=figure_shape)
@@ -322,7 +323,7 @@ class Plotter:
 
         markercycler = cycle(self.__marker_styles)
         for i, array in enumerate(y):
-            
+
             if torch.is_tensor(array):
                 data = array.cpu().detach().numpy()
             else:
@@ -332,11 +333,11 @@ class Plotter:
                 ax.scatter(x, data, label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
             if i < len(std):
                 ax.errorbar(x, data, std[i], label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
-        
+
         linecycler = cycle(self.__lines_styles)
         for i, true_value in enumerate(true_values):
-            ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}',c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
-            
+            ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}', c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
+
         if plot_legend:
             plt.legend(loc=legend_loc)
 

+ 56 - 56
states_training.py

@@ -7,7 +7,7 @@ from src.dataset import PandemicDataset, Norms
 from src.problem import ReducedSIRProblem
 from src.dinn import DINN, Scheduler, Activation
 
-ALPHA = [1/14, 1/5]
+ALPHA = [1 / 14, 1 / 5]
 DO_STATES = True
 DO_SYNTHETIC = False
 
@@ -16,66 +16,66 @@ ITERATIONS = 13
 state_starting_index = 0
 
 if "1" in sys.argv:
-     state_starting_index = 8
-
-STATE_LOOKUP = {'Schleswig_Holstein' : 2897000,
-                'Hamburg' : 1841000, 
-                'Niedersachsen' : 7982000, 
-                'Bremen' : 569352,
-                'Nordrhein_Westfalen' : 17930000,
-                'Hessen' : 6266000,
-                'Rheinland_Pfalz' : 4085000,
-                'Baden_Wuerttemberg' : 11070000,
-                'Bayern' : 13080000,
-                'Saarland' : 990509,
-                'Berlin' : 3645000,
-                'Brandenburg' : 2641000,
-                'Mecklenburg_Vorpommern' : 1610000,
-                'Sachsen' : 4078000,
-                'Sachsen_Anhalt' : 2208000,
-                'Thueringen' : 2143000}
+    state_starting_index = 8
+
+STATE_LOOKUP = {'Schleswig_Holstein': 2897000,
+                'Hamburg': 1841000,
+                'Niedersachsen': 7982000,
+                'Bremen': 569352,
+                'Nordrhein_Westfalen': 17930000,
+                'Hessen': 6266000,
+                'Rheinland_Pfalz': 4085000,
+                'Baden_Wuerttemberg': 11070000,
+                'Bayern': 13080000,
+                'Saarland': 990509,
+                'Berlin': 3645000,
+                'Brandenburg': 2641000,
+                'Mecklenburg_Vorpommern': 1610000,
+                'Sachsen': 4078000,
+                'Sachsen_Anhalt': 2208000,
+                'Thueringen': 2143000}
 
 if DO_SYNTHETIC:
-    alpha = 1/3
+    alpha = 1 / 3
     covid_data = np.genfromtxt(f'./datasets/I_data.csv', delimiter=',')
     for i in range(ITERATIONS):
-        dataset = PandemicDataset('Synthetic I', 
-                                    ['I'], 
-                                    7.6e6, 
-                                    *covid_data, 
-                                    norm_name=Norms.CONSTANT, 
-                                    use_scaled_time=True)
+        dataset = PandemicDataset('Synthetic I',
+                                  ['I'],
+                                  7.6e6,
+                                  *covid_data,
+                                  norm_name=Norms.CONSTANT,
+                                  use_scaled_time=True)
 
         problem = ReducedSIRProblem(dataset, alpha)
-        dinn = DINN(2, 
-                    dataset, 
-                    [], 
-                    problem, 
-                    None, 
-                    state_variables=['R_t'], 
-                    hidden_size=100, 
-                    hidden_layers=4, 
+        dinn = DINN(2,
+                    dataset,
+                    [],
+                    problem,
+                    None,
+                    state_variables=['R_t'],
+                    hidden_size=100,
+                    hidden_layers=4,
                     activation_layer=torch.nn.Tanh(),
                     activation_output=Activation.POWER)
 
-        dinn.configure_training(1e-3, 
-                                20000, 
-                                scheduler_class=Scheduler.POLYNOMIAL, 
+        dinn.configure_training(1e-3,
+                                20000,
+                                scheduler_class=Scheduler.POLYNOMIAL,
                                 lambda_physics=1e-6,
                                 verbose=True)
         dinn.train(verbose=True, do_split_training=True)
 
         dinn.save_training_process(f'synthetic_{i}')
-        #r_t = dinn.get_output(1).detach().cpu().numpy()
+        # r_t = dinn.get_output(1).detach().cpu().numpy()
 
-        #with open(f'./results/synthetic_{i}.csv', 'w', newline='') as csvfile:
-                #writer = csv.writer(csvfile, delimiter=',')
-                #writer.writerow(r_t)
+        # with open(f'./results/synthetic_{i}.csv', 'w', newline='') as csvfile:
+        # writer = csv.writer(csvfile, delimiter=',')
+        # writer.writerow(r_t)
 
 for iteration in range(ITERATIONS):
     if iteration <= 2:
-         print('skip first three iteration, as it was already done')
-         continue
+        print('skip first three iteration, as it was already done')
+        continue
     if DO_STATES:
         for state_idx in range(state_starting_index, state_starting_index + 8):
             state = list(STATE_LOOKUP.keys())[state_idx]
@@ -91,22 +91,22 @@ for iteration in range(ITERATIONS):
 
                 problem = ReducedSIRProblem(dataset, alpha)
 
-                dinn = DINN(2, 
-                            dataset, 
-                            [], 
-                            problem, 
-                            None, 
-                            state_variables=['R_t'], 
-                            hidden_size=100, 
-                            hidden_layers=4, 
+                dinn = DINN(2,
+                            dataset,
+                            [],
+                            problem,
+                            None,
+                            state_variables=['R_t'],
+                            hidden_size=100,
+                            hidden_layers=4,
                             activation_layer=torch.nn.Tanh(),
                             activation_output=Activation.POWER)
 
-                dinn.configure_training(1e-3, 
-                                        25000, 
+                dinn.configure_training(1e-3,
+                                        25000,
                                         scheduler_class=Scheduler.POLYNOMIAL,
-                                        lambda_obs=1e2, 
-                                        lambda_physics=1e-6, 
+                                        lambda_obs=1e2,
+                                        lambda_physics=1e-6,
                                         verbose=True)
                 dinn.train(verbose=True, do_split_training=True)
 
@@ -115,4 +115,4 @@ for iteration in range(ITERATIONS):
                 r_t = dinn.get_output(1).detach().cpu().numpy()
                 with open(f'./results/{state}_{i}_{iteration}.csv', 'w', newline='') as csvfile:
                     writer = csv.writer(csvfile, delimiter=',')
-                    writer.writerow(r_t)
+                    writer.writerow(r_t)