123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- import argparse
- import os
- from tqdm import tqdm
- import numpy as np
- import torch
- from torch import nn
- from torch.autograd import Variable
- from torch.utils.data import DataLoader
- from torchvision.utils import save_image
- from torchinfo import summary
- from py.PyTorchData import create_dataloader, model_output_to_image
- from py.Autoencoder2 import Autoencoder
def eval_autoencoder(model: Autoencoder, dataloader: DataLoader, name: str, set_name: str, device: str = "cpu", criterion = nn.MSELoss()):
    """Run the autoencoder over a dataset and store per-sample results.

    Every batch is passed through ``model.encoder`` and ``model.decoder``.
    For each sample the flattened latent code, the label yielded by the
    dataloader and the reconstruction loss are collected and finally written
    to ``./ae_train_NoBackup/{name}/eval/{set_name}.npz`` as the arrays
    ``encodeds``, ``labels`` and ``losses``.

    Args:
        model: Autoencoder exposing ``encoder`` and ``decoder`` submodules.
        dataloader: Yields ``(image_batch, labels)`` pairs.
        name: Name of the training session; selects the output folder.
        set_name: Name of the evaluated dataset; used as the output file name.
        device: PyTorch device to evaluate on (e.g. "cpu" or "cuda").
        criterion: Loss applied to each ``(input, reconstruction)`` pair.
            It must return a single-element tensor (as the default
            ``nn.MSELoss()`` does), because ``.item()`` is called on it.
    """
    eval_dir = f"./ae_train_NoBackup/{name}/eval"

    model = model.to(device)
    # Evaluation only: switch layers such as dropout/batch norm to inference mode.
    model.eval()
    print(f"Using {device} device")
    print(f"Saving evaluation results to {eval_dir}")
    os.makedirs(eval_dir, exist_ok=True)

    # Accumulators use names distinct from the loop variables so that the
    # per-batch labels cannot overwrite the result list.
    all_labels = []
    all_encodeds = []
    all_losses = []

    # No gradients are needed; this also keeps the stored losses from
    # holding on to the autograd graph.
    with torch.no_grad():
        for img_batch, batch_labels in tqdm(dataloader):
            img_batch = img_batch.to(device)
            # ===================forward=====================
            encoded = model.encoder(img_batch)
            output_batch = model.decoder(encoded)
            # Move to host memory before converting; .numpy() fails on CUDA tensors.
            encoded_flat = encoded.cpu().numpy().reshape(encoded.size(0), -1)
            for sample, output, label, enc_flat in zip(img_batch, output_batch, batch_labels, encoded_flat):
                all_losses.append(criterion(sample, output).item())
                all_encodeds.append(enc_flat)
                # Labels may be tensors or plain Python values depending on the dataset.
                all_labels.append(label.cpu().numpy() if torch.is_tensor(label) else label)

    # np.save needs an array argument and can hold only one array; np.savez
    # stores the three result arrays side by side in a single .npz file.
    np.savez(
        f"{eval_dir}/{set_name}.npz",
        encodeds=np.asarray(all_encodeds),
        labels=np.asarray(all_labels),
        losses=np.asarray(all_losses),
    )
if __name__ == "__main__":
    # Command-line entry point: parse arguments, build the dataloader and
    # model, print some diagnostics and run the evaluation.
    parser = argparse.ArgumentParser(description="Autoencoder eval script")
    parser.add_argument("name", type=str, help="Name of the training session (name of the save folder)")
    parser.add_argument("model_name", type=str, help="Filename of the model (e.g. model_120.pth)")
    parser.add_argument("set_name", type=str, help="Name of the dataset (e.g. train or test)")
    parser.add_argument("img_folder", type=str, help="Path to directory containing train images (may contain subfolders)")
    parser.add_argument("--device", type=str, help="PyTorch device to train on (cpu or cuda)", default="cpu")
    parser.add_argument("--batch_size", type=int, help="Batch size (>=1)", default=32)
    parser.add_argument("--image_transforms", action="store_true", help="Truncate and resize images (only enable if the input images have not been truncated resized to the target size already)")

    args = parser.parse_args()

    if args.image_transforms:
        print("Image transforms enabled: Images will be truncated and resized.")
    else:
        print("Image transforms disabled: Images are expected to be of the right size.")

    dataloader = create_dataloader(args.img_folder, batch_size=args.batch_size, skip_transforms=not args.image_transforms)
    model = Autoencoder()
    # NOTE(review): args.model_name is parsed but the checkpoint is never
    # loaded, so the model is evaluated with its freshly initialised weights.
    # The checkpoint location is not visible in this file -- TODO: load the
    # state dict (torch.load + model.load_state_dict) from the session folder.

    print("Model:")
    summary(model, (args.batch_size, 3, 256, 256))

    print("Is CUDA available:", torch.cuda.is_available())
    print(f"Devices: ({torch.cuda.device_count()})")
    for i in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_name(i))

    # The former `if args.noise:` check was removed: no --noise option is
    # defined, so reading args.noise raised AttributeError before evaluation.

    # The third argument of eval_autoencoder is the session folder name, so
    # pass args.name (not the checkpoint file name).
    eval_autoencoder(model, dataloader, args.name, args.set_name, args.device)
|