# train_autoencoder.py
  1. # Copyright (c) 2023 Felix Kleinsteuber and Computer Vision Group, Friedrich Schiller University Jena
  2. # Approach 4: Autoencoder
  3. # This script is used for training an autoencoder on Lapse images.
  4. # See eval_autoencoder.py for evaluation.
  5. import argparse
  6. import os
  7. from tqdm import tqdm
  8. import torch
  9. import numpy as np
  10. import random
  11. from torch import nn
  12. from torch.autograd import Variable
  13. from torch.utils.data import DataLoader
  14. from torchvision.utils import save_image
  15. from torchinfo import summary
  16. from py.PyTorchData import create_dataloader, model_output_to_image
  17. from py.Dataset import Dataset
  18. from py.Autoencoder2 import Autoencoder
  19. def train_autoencoder(model: Autoencoder, train_dataloader: DataLoader, name: str, device: str = "cpu", num_epochs=100, criterion = nn.MSELoss(), lr: float = 1e-3, weight_decay: float = 1e-5, noise: bool = False, sparse: bool = False, reg_rate: float = 1e-4, noise_var: float = 0.015):
  20. model = model.to(device)
  21. print(f"Using {device} device")
  22. optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
  23. print(f"Saving models to ./ae_train_NoBackup/{name}")
  24. os.makedirs(f"./ae_train_NoBackup/{name}", exist_ok=True)
  25. print(f"Training for {num_epochs} epochs.")
  26. for epoch in range(num_epochs):
  27. total_loss = 0
  28. total_reg_loss = 0
  29. for img, _ in tqdm(train_dataloader):
  30. optimizer.zero_grad()
  31. img = Variable(img).to(device)
  32. input = img
  33. if noise:
  34. input = input + (noise_var ** 0.5) * torch.randn(img.size(), device=device)
  35. # ===================forward=====================
  36. latent = model.encoder(input)
  37. output = model.decoder(latent)
  38. loss = criterion(output, img)
  39. total_loss += loss.item()
  40. if sparse:
  41. reg_loss = reg_rate * torch.mean(torch.abs(latent))
  42. total_reg_loss += reg_loss.item()
  43. loss += reg_loss
  44. # ===================backward====================
  45. loss.backward()
  46. optimizer.step()
  47. # ===================log========================
  48. dsp_epoch = epoch + 1
  49. if sparse:
  50. print('epoch [{}/{}], loss: {:.4f} + reg loss: {:.4f}'.format(dsp_epoch, num_epochs, total_loss, total_reg_loss))
  51. else:
  52. print('epoch [{}/{}], loss: {:.4f}'.format(dsp_epoch, num_epochs, total_loss))
  53. # log file
  54. with open(f"./ae_train_NoBackup/{name}/log.csv", "a+") as f:
  55. f.write(f"{dsp_epoch},{total_loss},{total_reg_loss}\n")
  56. # output image
  57. if epoch % 10 == 0:
  58. pic = model_output_to_image(output.cpu().data)
  59. save_image(pic, f"./ae_train_NoBackup/{name}/image_{dsp_epoch:03d}.png")
  60. # model checkpoint
  61. if epoch % 10 == 0:
  62. torch.save(model.state_dict(), f"./ae_train_NoBackup/{name}/model_{dsp_epoch:03d}.pth")
  63. torch.save(model.state_dict(), f"./ae_train_NoBackup/{name}/model_{num_epochs:03d}.pth")
  64. if __name__ == "__main__":
  65. parser = argparse.ArgumentParser(description="Autoencoder train script")
  66. parser.add_argument("name", type=str, help="Name of the training session (name of the save folder)")
  67. parser.add_argument("dataset_folder", type=str, help="Path to dataset folder containing sessions")
  68. parser.add_argument("session", type=str, help="Session name")
  69. parser.add_argument("--device", type=str, help="PyTorch device to train on (cpu or cuda)", default="cpu")
  70. parser.add_argument("--epochs", type=int, help="Number of epochs", default=100)
  71. parser.add_argument("--batch_size", type=int, help="Batch size (>=1)", default=32)
  72. parser.add_argument("--lr", type=float, help="Learning rate", default=1e-3)
  73. parser.add_argument("--reg_rate", type=float, help="Sparse regularization rate", default=1e-4)
  74. parser.add_argument("--dropout", type=float, help="Dropout rate on all layers", default=0.05)
  75. parser.add_argument("--latent", type=int, help="Number of latent features", default=512)
  76. parser.add_argument("--image_transforms", action="store_true", help="Truncate and resize images (only enable if the input images have not been truncated resized to the target size already)")
  77. parser.add_argument("--noise", action="store_true", help="Add Gaussian noise to model input")
  78. parser.add_argument("--noise_var", type=float, help="Noise variance", default=0.015)
  79. parser.add_argument("--sparse", action="store_true", help="Add L1 penalty to latent features")
  80. args = parser.parse_args()
  81. ds = Dataset(args.dataset_folder)
  82. session = ds.create_session(args.session)
  83. if args.image_transforms:
  84. print("Image transforms enabled: Images will be truncated and resized.")
  85. else:
  86. print("Image transforms disabled: Images are expected to be of the right size.")
  87. # torch.manual_seed(10810)
  88. # np.random.seed(10810)
  89. # random.seed(10810)
  90. data_loader = create_dataloader(session.get_lapse_folder(), batch_size=args.batch_size, skip_transforms=not args.image_transforms)
  91. model = Autoencoder(dropout=args.dropout, latent_features=args.latent)
  92. print("Model:")
  93. summary(model, (args.batch_size, 3, 256, 256))
  94. print("Is CUDA available:", torch.cuda.is_available())
  95. print(f"Devices: ({torch.cuda.device_count()})")
  96. for i in range(torch.cuda.device_count()):
  97. print(torch.cuda.get_device_name(i))
  98. if args.noise:
  99. print("Adding Gaussian noise to model input")
  100. if args.sparse:
  101. print("Adding L1 penalty to latent features (sparse)")
  102. train_autoencoder(model, data_loader, args.name, device=args.device, num_epochs=args.epochs, lr=args.lr, noise=args.noise, sparse=args.sparse, reg_rate=args.reg_rate, noise_var=args.noise_var)