PyTorchData.py

# Copyright (c) 2023 Felix Kleinsteuber and Computer Vision Group, Friedrich Schiller University Jena
# Functions related to approach 4 (autoencoder).
# For training and evaluation scripts, see ./train_autoencoder.py and ./eval_autoencoder.py.

import os

import matplotlib.pyplot as plt
from torchvision import io, transforms
from torch.utils.data import DataLoader, Dataset


# PyTorch dataset which loads images from a directory
class ImageDataset(Dataset):
    def __init__(self, img_dir: str, transform=None, labeler=None, filter=lambda filename: True):
        """Create a new PyTorch dataset from images in a directory.

        Args:
            img_dir (str): Source directory which contains the images.
            transform (lambda img: transformed_img, optional): Input transform function. Defaults to None.
            labeler (lambda str: int, optional): Labeling function. Input is the filename, output is the label. Defaults to None.
            filter (lambda str: bool, optional): Input filter function. Input is the filename. Images for which filter returns False are skipped. Defaults to no filtering.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.labeler = labeler
        # collect all .jpg files in the directory that pass the filter
        with os.scandir(img_dir) as it:
            self.files = [entry.name for entry in it if entry.name.endswith(".jpg") and entry.is_file() and filter(entry.name)]
        print(f"{len(self.files)} files found")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.files[idx])
        img = io.read_image(img_path)
        # apply transform function
        if self.transform:
            img = self.transform(img)
        label = 0
        # get label
        if self.labeler:
            label = self.labeler(self.files[idx])
        return img, label
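
# Illustrative usage sketch (not part of the original module): constructing an
# ImageDataset directly. The folder path and labeling rule below are assumptions
# made up for this example.
def _example_image_dataset(img_dir: str = "./example_images"):
    # Hypothetical labeling rule: files whose name contains "anomalous" get label 1.
    dataset = ImageDataset(img_dir, labeler=lambda filename: int("anomalous" in filename))
    # Without a transform, __getitem__ returns the raw uint8 tensor from io.read_image.
    img, label = dataset[0]
    print(img.shape, label)
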
def create_dataloader(img_folder: str, target_size: tuple = (256, 256), batch_size: int = 32, shuffle: bool = True, truncate_y: tuple = (40, 40), labeler=None, skip_transforms: bool = False, filter=lambda filename: True) -> DataLoader:
    """Creates a PyTorch DataLoader from the given image folder.

    Args:
        img_folder (str): Folder containing images. (The folder is scanned for .jpg images.)
        target_size (tuple, optional): Model input size. Images are resized to this size. Defaults to (256, 256).
        batch_size (int, optional): Batch size. Defaults to 32.
        shuffle (bool, optional): Shuffle images. Good for training, useless for testing. Defaults to True.
        truncate_y (tuple, optional): (a, b), cut off the first a and the last b pixel rows of the unresized image. Defaults to (40, 40).
        labeler (lambda(filename: str) -> int, optional): Lambda that maps every filename to an int label. By default, all labels are 0. Defaults to None.
        skip_transforms (bool, optional): Skip truncate and resize transforms (if the images are already truncated and resized). Defaults to False.
        filter (lambda(filename: str) -> bool, optional): Additional filter by filename. Defaults to lambda filename: True.

    Returns:
        DataLoader: PyTorch DataLoader
    """
    def crop_lambda(img):
        # cut off the first truncate_y[0] and the last truncate_y[1] pixel rows
        return transforms.functional.crop(img, truncate_y[0], 0, img.shape[-2] - truncate_y[0] - truncate_y[1], img.shape[-1])

    if skip_transforms:
        transform = transforms.Compose([
            transforms.Lambda(lambda img: img.float()),
            transforms.Normalize((127.5), (127.5)),  # map uint8 range [0, 255] to [-1, 1]
        ])
    else:
        transform = transforms.Compose([
            transforms.Lambda(crop_lambda),
            transforms.ToPILImage(),
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5)),  # map [0, 1] to [-1, 1]
        ])
    data = ImageDataset(img_folder, transform=transform, labeler=labeler, filter=filter)
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle)
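
# Illustrative usage sketch (not part of the original module): a typical training
# DataLoader built with create_dataloader. The folder path is an assumption made
# up for this example.
def _example_create_dataloader(img_dir: str = "./example_images"):
    # Images are cropped by truncate_y, resized to target_size and normalized to [-1, 1].
    loader = create_dataloader(img_dir, target_size=(256, 256), batch_size=32, shuffle=True)
    for imgs, labels in loader:
        # imgs: float tensor of shape (batch_size, 3, 256, 256); labels: tensor of shape (batch_size,)
        print(imgs.shape, labels.shape)
        break
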
def model_output_to_image(y):
    """Converts the raw model output back to an image by normalizing and clamping it to [0, 1] and reshaping it.

    Args:
        y (PyTorch tensor): Autoencoder output.

    Returns:
        PyTorch tensor: Image from autoencoder output.
    """
    y = 0.5 * (y + 1)  # normalize back to [0, 1]
    y = y.clamp(0, 1)  # clamp to [0, 1]
    y = y.view(y.size(0), 3, 256, 256)
    return y
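
# Illustrative usage sketch (not part of the original module): converting a batch of
# autoencoder outputs back to displayable images. The random tensor stands in for
# real model output.
def _example_model_output_to_image():
    import torch
    y = torch.randn(4, 3 * 256 * 256)         # fake model output, roughly in [-1, 1]
    imgs = model_output_to_image(y)           # shape (4, 3, 256, 256), values in [0, 1]
    plt.imshow(imgs[0].permute(1, 2, 0).numpy())  # channels-last layout for matplotlib
    plt.show()
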
def get_log(name: str, display: bool = False, figsize: tuple = (12, 6)):
    """Parses a training log file and returns the iteration and loss values.

    Args:
        name (str): Name of training session.
        display (bool, optional): If True, plot the training curve. Defaults to False.
        figsize (tuple, optional): Plot size if display is True. Defaults to (12, 6).

    Returns:
        iterations (list of int), losses (list of float): Training curve values
    """
    its = []
    losses = []
    with open(f"./ae_train_NoBackup/{name}/log.csv", "r") as f:
        for line in f:
            it, loss = line.rstrip().split(",")[:2]
            its.append(int(it))
            losses.append(float(loss))
    if display:
        plt.figure(figsize=figsize)
        plt.plot(its, losses)
        plt.title(f"Training curve ({name})")
        plt.xlabel("Epoch")
        plt.ylabel("MSE Loss")
        plt.show()
    return its, losses
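
# Illustrative usage sketch (not part of the original module): reading a training log
# and plotting it manually. The session name "ae_example" is an assumption made up
# for this example and must match a folder in ./ae_train_NoBackup/.
def _example_get_log():
    its, losses = get_log("ae_example", display=False)
    plt.figure(figsize=(12, 6))
    plt.plot(its, losses)
    plt.yscale("log")  # loss curves are often easier to read on a logarithmic scale
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.show()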