1
0

train_autoencoder.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import argparse
  2. import os
  3. from tqdm import tqdm
  4. import torch
  5. from torch import nn
  6. from torch.autograd import Variable
  7. from torch.utils.data import DataLoader
  8. from torchvision.utils import save_image
  9. from torchinfo import summary
  10. from py.PyTorchData import create_dataloader, model_output_to_image
  11. from py.Autoencoder2 import Autoencoder
  12. def train_autoencoder(model: nn.Module, train_dataloader: DataLoader, name: str, device: str = "cpu", num_epochs=100, criterion = nn.MSELoss(), lr: float = 1e-3, weight_decay: float = 1e-5, noise: bool = False):
  13. model = model.to(device)
  14. print(f"Using {device} device")
  15. optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
  16. print(f"Saving models to ./ae_train_NoBackup/{name}")
  17. os.makedirs(f"./ae_train_NoBackup/{name}", exist_ok=True)
  18. print(f"Training for {num_epochs} epochs.")
  19. for epoch in range(num_epochs):
  20. total_loss = 0
  21. for img, _ in tqdm(train_dataloader):
  22. img = Variable(img).to(device)
  23. input = img
  24. if noise:
  25. input = input + (0.01 ** 0.5) * torch.randn(img.size(), device=device)
  26. # ===================forward=====================
  27. output = model(input)
  28. loss = criterion(output, img)
  29. # ===================backward====================
  30. optimizer.zero_grad()
  31. loss.backward()
  32. optimizer.step()
  33. total_loss += loss.data
  34. # ===================log========================
  35. dsp_epoch = epoch + 1
  36. print('epoch [{}/{}], loss:{:.4f}'.format(dsp_epoch, num_epochs, total_loss))
  37. # log file
  38. with open(f"./ae_train_NoBackup/{name}/log.csv", "a+") as f:
  39. f.write(f"{dsp_epoch},{total_loss}\n")
  40. # output image
  41. if epoch % 2 == 0:
  42. pic = model_output_to_image(output.cpu().data)
  43. save_image(pic, f"./ae_train_NoBackup/{name}/image_{dsp_epoch:03d}.png")
  44. # model checkpoint
  45. if epoch % 5 == 0:
  46. torch.save(model.state_dict(), f"./ae_train_NoBackup/{name}/model_{dsp_epoch:03d}.pth")
  47. torch.save(model.state_dict(), f"./ae_train_NoBackup/{name}/model_{num_epochs:03d}.pth")
  48. if __name__ == "__main__":
  49. parser = argparse.ArgumentParser(description="Autoencoder train script")
  50. parser.add_argument("name", type=str, help="Name of the training session (name of the save folder)")
  51. parser.add_argument("img_folder", type=str, help="Path to directory containing train images (may contain subfolders)")
  52. parser.add_argument("--device", type=str, help="PyTorch device to train on (cpu or cuda)", default="cpu")
  53. parser.add_argument("--epochs", type=int, help="Number of epochs", default=100)
  54. parser.add_argument("--batch_size", type=int, help="Batch size (>=1)", default=32)
  55. parser.add_argument("--lr", type=float, help="Learning rate", default=1e-3)
  56. parser.add_argument("--image_transforms", action="store_true", help="Truncate and resize images (only enable if the input images have not been truncated resized to the target size already)")
  57. parser.add_argument("--noise", action="store_true", help="Add Gaussian noise to model input")
  58. args = parser.parse_args()
  59. if args.image_transforms:
  60. print("Image transforms enabled: Images will be truncated and resized.")
  61. else:
  62. print("Image transforms disabled: Images are expected to be of the right size.")
  63. data_loader = create_dataloader(args.img_folder, batch_size=args.batch_size, skip_transforms=not args.image_transforms)
  64. model = Autoencoder()
  65. print("Model:")
  66. summary(model, (args.batch_size, 3, 256, 256))
  67. print("Is CUDA available:", torch.cuda.is_available())
  68. print(f"Devices: ({torch.cuda.device_count()})")
  69. for i in range(torch.cuda.device_count()):
  70. print(torch.cuda.get_device_name(i))
  71. if args.noise:
  72. print("Adding Gaussian noise to model input")
  73. train_autoencoder(model, data_loader, args.name, device=args.device, num_epochs=args.epochs, lr=args.lr, noise=args.noise)