# eval_autoencoder.py
# Approach 4: Autoencoder
# This script is used for evaluating an autoencoder on Motion and Lapse images.
# See train_autoencoder.py for training.
import argparse
import os
from glob import glob

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from py.Autoencoder2 import Autoencoder
from py.Dataset import Dataset
from py.FileUtils import dump
from py.Labels import LABELS
from py.PyTorchData import create_dataloader
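
# A minimal usage sketch (the training-run name, dataset path, and session
# name below are hypothetical placeholders, not part of the repository):
#
#   python eval_autoencoder.py my_train_run /data/sessions Session01 --device cuda
#
# Results are written to ./ae_train_NoBackup/my_train_run/eval/ as one pickle
# file per image folder (Lapse and Motion).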

TRAIN_FOLDER = "./ae_train_NoBackup"


def load_autoencoder(train_name: str, device: str = "cpu", model_number: int = -1):
    """Load an autoencoder checkpoint of a training session (latest epoch by default)."""
    if model_number < 0:
        # No epoch given: use the latest checkpoint (zero-padded names sort chronologically).
        model_path = sorted(glob(f"{TRAIN_FOLDER}/{train_name}/model_*.pth"))[-1]
    else:
        model_path = f"{TRAIN_FOLDER}/{train_name}/model_{model_number:03d}.pth"
    print(f"Loading model from {model_path}... ", end="")
    model = Autoencoder()
    model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    model.eval()
    print("Loaded!")
    return model


def eval_autoencoder(model: Autoencoder, data_loader: DataLoader, device: str = "cpu"):
    """Run the autoencoder over all images of a data loader.

    Returns per-image reconstruction losses, flattened latent encodings, and labels.
    """
    losses = []     # reconstruction errors
    encodings = []  # latent representations for KDE
    labels = []
    with torch.no_grad():
        model = model.to(device)
        criterion = nn.MSELoss()
        for features, batch_labels in tqdm(data_loader):
            features = features.to(device)
            labels += batch_labels
            # forward pass: encode, then decode
            encoded = model.encoder(features)
            output_batch = model.decoder(encoded)
            # calculate and save the flattened latent representation and per-image loss
            encoded_flat = encoded.detach().cpu().numpy().reshape(encoded.size()[0], -1)
            for input_img, enc, output in zip(features, encoded_flat, output_batch):
                encodings.append(enc)
                losses.append(criterion(input_img, output).cpu().numpy())
    return np.array(losses), np.array(encodings), np.array(labels)
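

# The encodings returned above are collected "for KDE" (see the comment in
# eval_autoencoder): a density model fitted on encodings of one image set can
# score another set for anomalies. The helper below is a minimal sketch of
# that idea, assuming scikit-learn is available; the function name, the
# default bandwidth, and the choice of log-density as a normality score are
# illustrative assumptions, not this repository's API.
def kde_normality_scores(reference_encodings: np.ndarray, test_encodings: np.ndarray,
                         bandwidth: float = 1.0) -> np.ndarray:
    """Score test encodings by their log-density under a KDE fitted on reference encodings."""
    from sklearn.neighbors import KernelDensity  # local import: only needed for this sketch

    # Fit a Gaussian KDE on the reference encodings (e.g. from Lapse images)...
    kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(reference_encodings)
    # ...and return per-sample log-densities for the test encodings (e.g. from
    # Motion images); lower values suggest more anomalous images.
    return kde.score_samples(test_encodings)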


def main():
    parser = argparse.ArgumentParser(description="Autoencoder eval script - evaluates Motion and Lapse images of a session")
    parser.add_argument("name", type=str, help="Name of the training session (name of the save folder)")
    parser.add_argument("dataset_folder", type=str, help="Path to the dataset folder containing sessions")
    parser.add_argument("session", type=str, help="Session name")
    parser.add_argument("--device", type=str, help="PyTorch device to evaluate on (cpu or cuda)", default="cpu")
    parser.add_argument("--batch_size", type=int, help="Batch size (>= 1)", default=32)
    parser.add_argument("--model_number", type=int, help="Load the model save of a specific epoch (default: use latest)", default=-1)
    parser.add_argument("--image_transforms", action="store_true", help="Truncate and resize images (only enable if the input images have not been truncated and resized to the target size already)")
    args = parser.parse_args()

    if args.image_transforms:
        print("Image transforms enabled: Images will be truncated and resized.")
    else:
        print("Image transforms disabled: Images are expected to be of the right size.")

    ds = Dataset(args.dataset_folder)
    session = ds.create_session(args.session)

    # Target file names
    train_dir = os.path.join(TRAIN_FOLDER, args.name)
    save_dir = os.path.join(train_dir, "eval")
    os.makedirs(save_dir, exist_ok=True)
    lapse_eval_file = os.path.join(save_dir, f"{session.name}_lapse.pickle")
    motion_eval_file = os.path.join(save_dir, f"{session.name}_motion.pickle")

    # Load model
    model = load_autoencoder(args.name, args.device, args.model_number)

    # Check CUDA
    print("Is CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available() and args.device != "cuda":
        print("WARNING: CUDA is available but not activated! Use '--device cuda'.")
        print(f"Devices: ({torch.cuda.device_count()})")
        for i in range(torch.cuda.device_count()):
            print(torch.cuda.get_device_name(i))

    # Lapse eval
    if os.path.isfile(lapse_eval_file):
        print(f"Eval file for Lapse already exists ({lapse_eval_file}). Skipping Lapse evaluation...")
    else:
        print("Creating lapse data loader... ", end="")
        lapse_loader = create_dataloader(session.get_lapse_folder(), batch_size=args.batch_size, skip_transforms=not args.image_transforms, shuffle=False)
        results = eval_autoencoder(model, lapse_loader, args.device)
        dump(lapse_eval_file, results)
        print(f"Results saved to {lapse_eval_file}!")

    # Motion eval
    def is_labeled(filename: str) -> bool:
        # The image number is assumed to be the last five digits before the
        # four-character extension, e.g. "..._00042.jpg" -> 42.
        img_nr = int(filename[-9:-4])
        return (img_nr <= LABELS[session.name]["max"]) and (img_nr not in LABELS[session.name]["not_annotated"])

    def labeler(filename: str) -> int:
        # 0 = normal image, 1 = anomalous image
        is_normal = int(filename[-9:-4]) in LABELS[session.name]["normal"]
        return 0 if is_normal else 1

    if os.path.isfile(motion_eval_file):
        print(f"Eval file for Motion already exists ({motion_eval_file}). Skipping Motion evaluation...")
    else:
        print("Creating motion data loader... ", end="")
        motion_loader = create_dataloader(session.get_motion_folder(), batch_size=args.batch_size, skip_transforms=not args.image_transforms, shuffle=False, labeler=labeler, filter=is_labeled)
        results = eval_autoencoder(model, motion_loader, args.device)
        dump(motion_eval_file, results)
        print(f"Results saved to {motion_eval_file}!")
    print("Done.")


if __name__ == "__main__":
    main()
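
# The pickle files written above each contain a (losses, encodings, labels)
# tuple. A downstream analysis could, for example, use the reconstruction
# error as an anomaly score. Sketch (assumes py.FileUtils.dump writes a plain
# pickle and that scikit-learn is available; the path is a placeholder):
#
#   import pickle
#   from sklearn.metrics import roc_auc_score
#
#   with open("ae_train_NoBackup/<name>/eval/<session>_motion.pickle", "rb") as f:
#       losses, encodings, labels = pickle.load(f)
#   # Labels are 0 (normal) / 1 (anomalous); higher loss should mean more anomalous.
#   print("ROC AUC:", roc_auc_score(labels, losses))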