PyTorchData.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import os
  2. import matplotlib.pyplot as plt
  3. from torchvision import io, transforms
  4. from torch.utils.data import DataLoader, Dataset
  5. class ImageDataset(Dataset):
  6. def __init__(self, img_dir: str, transform = None, labeler = None, filter = lambda filename: True):
  7. self.img_dir = img_dir
  8. self.transform = transform
  9. self.labeler = labeler
  10. with os.scandir(img_dir) as it:
  11. self.files = [entry.name for entry in it if entry.name.endswith(".jpg") and entry.is_file() and filter(entry.name)]
  12. print(f"{len(self.files)} files found")
  13. def __len__(self):
  14. return len(self.files)
  15. def __getitem__(self, idx):
  16. img_path = os.path.join(self.img_dir, self.files[idx])
  17. img = io.read_image(img_path)
  18. if self.transform:
  19. img = self.transform(img)
  20. label = 0
  21. if self.labeler:
  22. label = self.labeler(self.files[idx])
  23. return img, label
  24. def create_dataloader(img_folder: str, target_size: tuple = (256, 256), batch_size: int = 32, shuffle: bool = True, truncate_y: tuple = (40, 40), labeler = None, skip_transforms: bool = False, filter = lambda filename: True) -> DataLoader:
  25. """Creates a PyTorch DataLoader from the given image folder.
  26. Args:
  27. img_folder (str): Folder containing images. (All subfolders will be scanned for jpg images)
  28. target_size (tuple, optional): Model input size. Images are resized to this size. Defaults to (256, 256).
  29. batch_size (int, optional): Batch size. Defaults to 32.
  30. shuffle (bool, optional): Shuffle images. Good for training, useless for testing. Defaults to True.
  31. truncate_y (tuple, optional): (a, b), cut off the first a and the last b pixel rows of the unresized image. Defaults to (40, 40).
  32. labeler (lambda(filename: str) -> int, optional): Lambda that maps every filename to an int label. By default all labels are 0. Defaults to None.
  33. skip_transforms (bool, optional): Skip truncate and resize transforms. (If the images are already truncated and resized). Defaults to False.
  34. filter (lambda: str -> bool, optional): Additional filter by filename. Defaults to lambda filename: True.
  35. Returns:
  36. DataLoader: PyTorch DataLoader
  37. """
  38. def crop_lambda(img):
  39. return transforms.functional.crop(img, truncate_y[0], 0, img.shape[-2] - truncate_y[0] - truncate_y[1], img.shape[-1])
  40. transform = None
  41. if skip_transforms:
  42. transform = transforms.Compose([
  43. transforms.Lambda(lambda img: img.float()),
  44. transforms.Normalize((127.5), (127.5)) # min-max normalization to [-1, 1]
  45. ])
  46. else:
  47. transform = transforms.Compose([
  48. transforms.Lambda(crop_lambda),
  49. transforms.ToPILImage(),
  50. transforms.Resize(target_size),
  51. transforms.ToTensor(),
  52. transforms.Normalize((0.5), (0.5)) # min-max normalization to [-1, 1]
  53. ])
  54. data = ImageDataset(img_folder, transform=transform, labeler=labeler, filter=filter)
  55. return DataLoader(data, batch_size=batch_size, shuffle=shuffle)
  56. def model_output_to_image(y):
  57. y = 0.5 * (y + 1) # normalize back to [0, 1]
  58. y = y.clamp(0, 1) # clamp to [0, 1]
  59. y = y.view(y.size(0), 3, 256, 256)
  60. return y
  61. def get_log(name: str, display: bool = False, figsize: tuple = (12, 6)):
  62. its = []
  63. losses = []
  64. with open(f"./ae_train_NoBackup/{name}/log.csv", "r") as f:
  65. for line in f:
  66. it, loss = line.rstrip().split(",")[:2]
  67. its.append(int(it))
  68. losses.append(float(loss))
  69. if display:
  70. plt.figure(figsize=figsize)
  71. plt.plot(its, losses)
  72. plt.title(f"Training curve ({name})")
  73. plt.xlabel("Epoch")
  74. plt.ylabel("MSE Loss")
  75. plt.show()
  76. return its, losses