# Copyright (c) 2023 Felix Kleinsteuber and Computer Vision Group, Friedrich Schiller University Jena
# Approach 4: Autoencoder
# This script is used for training an autoencoder on Lapse images.
# See eval_autoencoder.py for evaluation.
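# Example invocation (sketch only; the script filename, dataset path, and session name below are placeholders):
#   python train_autoencoder.py my_training /path/to/dataset SessionName --device cuda --epochs 100 --noise --sparse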
import argparse
import os
from tqdm import tqdm
import torch
import numpy as np
import random
from torch import nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchinfo import summary

from py.PyTorchData import create_dataloader, model_output_to_image
from py.Dataset import Dataset
from py.Autoencoder2 import Autoencoder


def train_autoencoder(model: Autoencoder, train_dataloader: DataLoader, name: str, device: str = "cpu",
                      num_epochs: int = 100, criterion=nn.MSELoss(), lr: float = 1e-3, weight_decay: float = 1e-5,
                      noise: bool = False, sparse: bool = False, reg_rate: float = 1e-4, noise_var: float = 0.015):
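    """Train an autoencoder and save logs, example reconstructions, and checkpoints to ./ae_train_NoBackup/<name>.

    Args:
        model: Autoencoder instance to train.
        train_dataloader: DataLoader yielding (image, label) batches; labels are ignored.
        name: Name of the training session (also the name of the save folder).
        device: PyTorch device to train on ("cpu" or "cuda").
        num_epochs: Number of training epochs.
        criterion: Reconstruction loss.
        lr: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
        noise: If True, train as a denoising autoencoder by adding Gaussian noise to the input.
        sparse: If True, add an L1 penalty on the latent features.
        reg_rate: Weight of the L1 penalty (only used if sparse is True).
        noise_var: Variance of the Gaussian input noise (only used if noise is True).
    """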
    model = model.to(device)
    print(f"Using {device} device")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    print(f"Saving models to ./ae_train_NoBackup/{name}")
    os.makedirs(f"./ae_train_NoBackup/{name}", exist_ok=True)

    print(f"Training for {num_epochs} epochs.")
    for epoch in range(num_epochs):
        total_loss = 0
        total_reg_loss = 0
        for img, _ in tqdm(train_dataloader):
            optimizer.zero_grad()
            img = img.to(device)
            model_input = img
            if noise:
                # Denoising autoencoder: corrupt the input with Gaussian noise,
                # but reconstruct the clean image
                model_input = model_input + (noise_var ** 0.5) * torch.randn(img.size(), device=device)
            # ===================forward=====================
            latent = model.encoder(model_input)
            output = model.decoder(latent)
            loss = criterion(output, img)
            total_loss += loss.item()
            if sparse:
                # L1 penalty on the latent activations encourages sparse codes
                reg_loss = reg_rate * torch.mean(torch.abs(latent))
                total_reg_loss += reg_loss.item()
                loss += reg_loss
            # ===================backward====================
            loss.backward()
            optimizer.step()
        # ===================log========================
        dsp_epoch = epoch + 1
        if sparse:
            print('epoch [{}/{}], loss: {:.4f} + reg loss: {:.4f}'.format(dsp_epoch, num_epochs, total_loss, total_reg_loss))
        else:
            print('epoch [{}/{}], loss: {:.4f}'.format(dsp_epoch, num_epochs, total_loss))

        # log file
        with open(f"./ae_train_NoBackup/{name}/log.csv", "a+") as f:
            f.write(f"{dsp_epoch},{total_loss},{total_reg_loss}\n")

        # output image
        if epoch % 10 == 0:
            pic = model_output_to_image(output.cpu().data)
            save_image(pic, f"./ae_train_NoBackup/{name}/image_{dsp_epoch:03d}.png")

        # model checkpoint
        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"./ae_train_NoBackup/{name}/model_{dsp_epoch:03d}.pth")

    # final model state after the last epoch
    torch.save(model.state_dict(), f"./ae_train_NoBackup/{name}/model_{num_epochs:03d}.pth")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Autoencoder train script")
    parser.add_argument("name", type=str, help="Name of the training session (name of the save folder)")
    parser.add_argument("dataset_folder", type=str, help="Path to dataset folder containing sessions")
    parser.add_argument("session", type=str, help="Session name")
    parser.add_argument("--device", type=str, help="PyTorch device to train on (cpu or cuda)", default="cpu")
    parser.add_argument("--epochs", type=int, help="Number of epochs", default=100)
    parser.add_argument("--batch_size", type=int, help="Batch size (>=1)", default=32)
    parser.add_argument("--lr", type=float, help="Learning rate", default=1e-3)
    parser.add_argument("--reg_rate", type=float, help="Sparse regularization rate", default=1e-4)
    parser.add_argument("--dropout", type=float, help="Dropout rate on all layers", default=0.05)
    parser.add_argument("--latent", type=int, help="Number of latent features", default=512)
    parser.add_argument("--image_transforms", action="store_true", help="Truncate and resize images (only enable if the input images have not already been truncated and resized to the target size)")
    parser.add_argument("--noise", action="store_true", help="Add Gaussian noise to model input")
    parser.add_argument("--noise_var", type=float, help="Noise variance", default=0.015)
    parser.add_argument("--sparse", action="store_true", help="Add L1 penalty to latent features")
    args = parser.parse_args()
    ds = Dataset(args.dataset_folder)
    session = ds.create_session(args.session)

    if args.image_transforms:
        print("Image transforms enabled: Images will be truncated and resized.")
    else:
        print("Image transforms disabled: Images are expected to be of the right size.")

    # torch.manual_seed(10810)
    # np.random.seed(10810)
    # random.seed(10810)

    data_loader = create_dataloader(session.get_lapse_folder(), batch_size=args.batch_size, skip_transforms=not args.image_transforms)
    model = Autoencoder(dropout=args.dropout, latent_features=args.latent)
    print("Model:")
    summary(model, (args.batch_size, 3, 256, 256))

    print("Is CUDA available:", torch.cuda.is_available())
    print(f"Available CUDA devices: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_name(i))

    if args.noise:
        print("Adding Gaussian noise to model input")
    if args.sparse:
        print("Adding L1 penalty to latent features (sparse)")

    train_autoencoder(model, data_loader, args.name, device=args.device, num_epochs=args.epochs, lr=args.lr, noise=args.noise, sparse=args.sparse, reg_rate=args.reg_rate, noise_var=args.noise_var)
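
    # A saved checkpoint can later be restored roughly along these lines (sketch with placeholder
    # path and the default hyperparameters; see eval_autoencoder.py for the actual evaluation):
    #   model = Autoencoder(dropout=0.05, latent_features=512)
    #   model.load_state_dict(torch.load("./ae_train_NoBackup/<name>/model_100.pth"))
    #   model.eval()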