# PyTorchData.py
# Data-loading and utility functions for approach 4 (autoencoder).
# For the training and evaluation scripts, see ./train_autoencoder.py and ./eval_autoencoder.py.
  3. import os
  4. import matplotlib.pyplot as plt
  5. from torchvision import io, transforms
  6. from torch.utils.data import DataLoader, Dataset
  7. # PyTorch dataset instance which loads images from a directory
  8. class ImageDataset(Dataset):
  9. def __init__(self, img_dir: str, transform = None, labeler = None, filter = lambda filename: True):
  10. """Create a new PyTorch dataset from images in a directory.
  11. Args:
  12. img_dir (str): Source directory which contains the images.
  13. transform (lambda img: transformed_img, optional): Input transform function. Defaults to None.
  14. labeler (lambda str: int, optional): Labeling function. Input is the filename, output the label. Defaults to None.
  15. filter (lambda str: bool, optional): Input filter function. Input is the filename. Images where filter returns False are skipped. Defaults to no filtering.
  16. """
  17. self.img_dir = img_dir
  18. self.transform = transform
  19. self.labeler = labeler
  20. with os.scandir(img_dir) as it:
  21. self.files = [entry.name for entry in it if entry.name.endswith(".jpg") and entry.is_file() and filter(entry.name)]
  22. print(f"{len(self.files)} files found")
  23. def __len__(self):
  24. return len(self.files)
  25. def __getitem__(self, idx):
  26. img_path = os.path.join(self.img_dir, self.files[idx])
  27. img = io.read_image(img_path)
  28. # apply transform function
  29. if self.transform:
  30. img = self.transform(img)
  31. label = 0
  32. # get label
  33. if self.labeler:
  34. label = self.labeler(self.files[idx])
  35. return img, label
  36. def create_dataloader(img_folder: str, target_size: tuple = (256, 256), batch_size: int = 32, shuffle: bool = True, truncate_y: tuple = (40, 40), labeler = None, skip_transforms: bool = False, filter = lambda filename: True) -> DataLoader:
  37. """Creates a PyTorch DataLoader from the given image folder.
  38. Args:
  39. img_folder (str): Folder containing images. (All subfolders will be scanned for jpg images)
  40. target_size (tuple, optional): Model input size. Images are resized to this size. Defaults to (256, 256).
  41. batch_size (int, optional): Batch size. Defaults to 32.
  42. shuffle (bool, optional): Shuffle images. Good for training, useless for testing. Defaults to True.
  43. truncate_y (tuple, optional): (a, b), cut off the first a and the last b pixel rows of the unresized image. Defaults to (40, 40).
  44. labeler (lambda(filename: str) -> int, optional): Lambda that maps every filename to an int label. By default all labels are 0. Defaults to None.
  45. skip_transforms (bool, optional): Skip truncate and resize transforms. (If the images are already truncated and resized). Defaults to False.
  46. filter (lambda: str -> bool, optional): Additional filter by filename. Defaults to lambda filename: True.
  47. Returns:
  48. DataLoader: PyTorch DataLoader
  49. """
  50. def crop_lambda(img):
  51. return transforms.functional.crop(img, truncate_y[0], 0, img.shape[-2] - truncate_y[0] - truncate_y[1], img.shape[-1])
  52. transform = None
  53. if skip_transforms:
  54. transform = transforms.Compose([
  55. transforms.Lambda(lambda img: img.float()),
  56. transforms.Normalize((127.5), (127.5)) # min-max normalization to [-1, 1]
  57. ])
  58. else:
  59. transform = transforms.Compose([
  60. transforms.Lambda(crop_lambda),
  61. transforms.ToPILImage(),
  62. transforms.Resize(target_size),
  63. transforms.ToTensor(),
  64. transforms.Normalize((0.5), (0.5)) # min-max normalization to [-1, 1]
  65. ])
  66. data = ImageDataset(img_folder, transform=transform, labeler=labeler, filter=filter)
  67. return DataLoader(data, batch_size=batch_size, shuffle=shuffle)
  68. def model_output_to_image(y):
  69. """Converts the raw model output back to an image by normalizing and clamping it to [0, 1] and reshaping it.
  70. Args:
  71. y (PyTorch tensor): Autoencoder output.
  72. Returns:
  73. PyTorch tensor: Image from autoencoder output.
  74. """
  75. y = 0.5 * (y + 1) # normalize back to [0, 1]
  76. y = y.clamp(0, 1) # clamp to [0, 1]
  77. y = y.view(y.size(0), 3, 256, 256)
  78. return y
  79. def get_log(name: str, display: bool = False, figsize: tuple = (12, 6)):
  80. """Parses a training log file and returns the iteration and loss values.
  81. Args:
  82. name (str): Name of training session.
  83. display (bool, optional): If True, plot the training curve. Defaults to False.
  84. figsize (tuple, optional): Plot size if display is True. Defaults to (12, 6).
  85. Returns:
  86. iterations (list of int), losses (list of float): Training curve values
  87. """
  88. its = []
  89. losses = []
  90. with open(f"./ae_train_NoBackup/{name}/log.csv", "r") as f:
  91. for line in f:
  92. it, loss = line.rstrip().split(",")[:2]
  93. its.append(int(it))
  94. losses.append(float(loss))
  95. if display:
  96. plt.figure(figsize=figsize)
  97. plt.plot(its, losses)
  98. plt.title(f"Training curve ({name})")
  99. plt.xlabel("Epoch")
  100. plt.ylabel("MSE Loss")
  101. plt.show()
  102. return its, losses