# eval_autoencoder.py
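"""Evaluate a trained Autoencoder on an image folder.

Runs the encoder and decoder over every batch, collects per-image labels,
flattened latent codes, and reconstruction losses, and writes them to
./ae_train_NoBackup/<name>/eval/.

Example invocation (arguments are illustrative):
    python eval_autoencoder.py my_session model_120.pth test ./images --device cuda
"""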

import argparse
import os
from tqdm import tqdm
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchinfo import summary
from py.PyTorchData import create_dataloader, model_output_to_image
from py.Autoencoder2 import Autoencoder


def eval_autoencoder(model: Autoencoder, dataloader: DataLoader, name: str, set_name: str,
                     device: str = "cpu", criterion=nn.MSELoss()):
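    """Encode and decode every image in `dataloader`, collecting per-image results.

    For each image the label, the flattened latent code, and the reconstruction
    loss under `criterion` are stored, then saved to
    ./ae_train_NoBackup/<name>/eval/<set_name>.npz.
    """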
    model = model.to(device)
    print(f"Using {device} device")
    print(f"Saving evaluation results to ./ae_train_NoBackup/{name}/eval")
    os.makedirs(f"./ae_train_NoBackup/{name}/eval", exist_ok=True)

    labels = []
    encodeds = []
    losses = []
    for img, label_batch in tqdm(dataloader):
        img_batch = Variable(img).to(device)
        # ===================forward=====================
        encoded = model.encoder(img_batch)
        encoded_flat = encoded.detach().cpu().numpy().reshape(encoded.size()[0], -1)
        output_batch = model.decoder(encoded)
        for inp, output, label, enc_flat in zip(img, output_batch, label_batch, encoded_flat):
            losses.append(criterion(output.detach().cpu(), inp).item())
            encodeds.append(enc_flat)
            labels.append(label)
    # Save labels, flattened encodings, and per-image losses for this dataset split.
    np.savez(f"./ae_train_NoBackup/{name}/eval/{set_name}.npz",
             labels=np.array(labels), encoded=np.array(encodeds), losses=np.array(losses))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Autoencoder eval script")
    parser.add_argument("name", type=str, help="Name of the training session (name of the save folder)")
    parser.add_argument("model_name", type=str, help="Filename of the model (e.g. model_120.pth)")
    parser.add_argument("set_name", type=str, help="Name of the dataset (e.g. train or test)")
    parser.add_argument("img_folder", type=str, help="Path to directory containing the images (may contain subfolders)")
    parser.add_argument("--device", type=str, help="PyTorch device to evaluate on (cpu or cuda)", default="cpu")
    parser.add_argument("--batch_size", type=int, help="Batch size (>= 1)", default=32)
    parser.add_argument("--image_transforms", action="store_true", help="Truncate and resize images (only enable if the input images have not already been truncated and resized to the target size)")
    parser.add_argument("--noise", action="store_true", help="Add Gaussian noise to the model input")
    args = parser.parse_args()

    if args.image_transforms:
        print("Image transforms enabled: Images will be truncated and resized.")
    else:
        print("Image transforms disabled: Images are expected to be of the right size.")
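    # skip_transforms is the inverse of --image_transforms: already preprocessed images are loaded as-is.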
    dataloader = create_dataloader(args.img_folder, batch_size=args.batch_size, skip_transforms=not args.image_transforms)

    model = Autoencoder()
    # Load the trained weights; the checkpoint is assumed to sit in the session's save folder.
    model.load_state_dict(torch.load(f"./ae_train_NoBackup/{args.name}/{args.model_name}", map_location="cpu"))
  48. print("Model:")
  49. summary(model, (args.batch_size, 3, 256, 256))
  50. print("Is CUDA available:", torch.cuda.is_available())
  51. print(f"Devices: ({torch.cuda.device_count()})")
  52. for i in range(torch.cuda.device_count()):
  53. print(torch.cuda.get_device_name(i))

    if args.noise:
        print("Adding Gaussian noise to model input")

    eval_autoencoder(model, dataloader, args.name, args.set_name, args.device)