PyTorchData.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os
  2. import matplotlib.pyplot as plt
  3. from torchvision import io, transforms
  4. from torch.utils.data import DataLoader, Dataset
  5. class ImageDataset(Dataset):
  6. def __init__(self, img_dir: str, transform = None, labeler = None):
  7. self.img_dir = img_dir
  8. self.transform = transform
  9. self.labeler = labeler
  10. with os.scandir(img_dir) as it:
  11. self.files = [entry.name for entry in it if entry.name.endswith(".jpg") and entry.is_file()]
  12. def __len__(self):
  13. return len(self.files)
  14. def __getitem__(self, idx):
  15. img_path = os.path.join(self.img_dir, self.files[idx])
  16. img = io.read_image(img_path)
  17. if self.transform:
  18. img = self.transform(img)
  19. label = 0
  20. if self.labeler:
  21. label = self.labeler(self.files[idx])
  22. return img, label
  23. def create_dataloader(img_folder: str, target_size: tuple = (256, 256), batch_size: int = 32, shuffle: bool = True, truncate_y: tuple = (40, 40), labeler = None, skip_transforms: bool = False) -> DataLoader:
  24. """Creates a PyTorch DataLoader from the given image folder.
  25. Args:
  26. img_folder (str): Folder containing images. (All subfolders will be scanned for jpg images)
  27. target_size (tuple, optional): Model input size. Images are resized to this size. Defaults to (256, 256).
  28. batch_size (int, optional): Batch size. Defaults to 32.
  29. shuffle (bool, optional): Shuffle images. Good for training, useless for testing. Defaults to True.
  30. truncate_y (tuple, optional): (a, b), cut off the first a and the last b pixel rows of the unresized image. Defaults to (40, 40).
  31. labeler (lambda(filename: str) -> int, optional): Lambda that maps every filename to an int label. By default all labels are 0. Defaults to None.
  32. skip_transforms (bool, optional): Skip truncate and resize transforms. (If the images are already truncated and resized). Defaults to False.
  33. Returns:
  34. DataLoader: PyTorch DataLoader
  35. """
  36. def crop_lambda(img):
  37. return transforms.functional.crop(img, truncate_y[0], 0, img.shape[-2] - truncate_y[0] - truncate_y[1], img.shape[-1])
  38. transform = None
  39. if skip_transforms:
  40. transform = transforms.Compose([
  41. transforms.Lambda(lambda img: img.float()),
  42. transforms.Normalize((127.5), (127.5)) # min-max normalization to [-1, 1]
  43. ])
  44. else:
  45. transform = transforms.Compose([
  46. transforms.Lambda(crop_lambda),
  47. transforms.ToPILImage(),
  48. transforms.Resize(target_size),
  49. transforms.ToTensor(),
  50. transforms.Normalize((0.5), (0.5)) # min-max normalization to [-1, 1]
  51. ])
  52. data = ImageDataset(img_folder, transform=transform, labeler=labeler)
  53. return DataLoader(data, batch_size=batch_size, shuffle=shuffle)
  54. def model_output_to_image(y):
  55. y = 0.5 * (y + 1) # normalize back to [0, 1]
  56. y = y.clamp(0, 1) # clamp to [0, 1]
  57. y = y.view(y.size(0), 3, 256, 256)
  58. return y
  59. def get_log(name: str, display: bool = False, figsize: tuple = (12, 6)):
  60. its = []
  61. losses = []
  62. with open(f"./ae_train_NoBackup/{name}/log.csv", "r") as f:
  63. for line in f:
  64. it, loss = line.rstrip().split(",")
  65. its.append(int(it))
  66. losses.append(float(loss))
  67. if display:
  68. plt.figure(figsize=figsize)
  69. plt.plot(its, losses)
  70. plt.title(f"Training curve ({name})")
  71. plt.xlabel("Epoch")
  72. plt.ylabel("MSE Loss")
  73. plt.show()
  74. return its, losses