resize_images.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. #!/usr/bin/env python
  2. if __name__ != '__main__': raise Exception("Do not import me!")
  3. import logging
  4. import multiprocessing as mp
  5. import os
  6. import shutil
  7. from PIL import Image
  8. from cvargparse import BaseParser
  9. from tqdm.auto import tqdm
  10. from cvargparse import cvdataclass
  11. from functools import partial
  12. from pathlib import Path
  @cvdataclass
  class Args:
      """Command-line options for the image resizer, parsed via ``BaseParser(Args)``."""
      # Root directory that is searched recursively for images.
      source: str = "source"
      # Output root; each image is written to the same relative path it had under `source`.
      destination: str = "dest"
      # Number of worker processes. Values < 1 (including this default) resize sequentially
      # in the current process.
      # NOTE(review): -1 commonly means "all cores" in other tools; here it disables the pool — confirm intended.
      n_jobs: int = -1
      # File suffixes to process. File suffixes are lowercased before the membership check,
      # so entries should be lowercase and include the leading dot.
      extensions: tuple = (".jpg", ".jpeg", ".png")
      # Target length in pixels of the pinned side (the longer side, or the shorter one with fit_short).
      size: int = 1000
      # Scale the shorter side to `size` instead of the longer one.
      fit_short: bool = False
      # Delete an already existing destination before writing; otherwise an existing destination is an error.
      remove_existing: bool = False
  def resize(name, *, source: Path, dest: Path, size: int, fit_short: bool = False):
      """Resize one image and save it under the same relative path in *dest*.

      Args:
          name: Path of the image relative to *source*. The result is written
              to ``dest / name``, so the source directory layout is mirrored.
          source: Root directory the image is read from.
          dest: Root directory the resized image is written to. Missing parent
              directories are created.
          size: Target length in pixels of the pinned side.
          fit_short: If False, the longer side is scaled to *size*; if True,
              the shorter side is. The other side follows the original aspect
              ratio. Images smaller than *size* are scaled up as well.

      The output format follows the file suffix of *name*, which is unchanged.
      """
      src = source / name
      dst = dest / name
      with Image.open(src) as im:
          w, h = im.size
          # Width is the long side of a landscape image (w > h). It is the side
          # pinned to `size` unless the short side was requested, and vice versa.
          # For square images both branches give (size, size).
          pin_width = (w > h) != fit_short
          # Round rather than truncate to keep the aspect ratio as close as possible.
          # Clamp to at least 1 px so an extreme aspect ratio can never produce a
          # 0-pixel side.
          if pin_width:
              W, H = size, max(1, round(size * h / w))
          else:
              W, H = max(1, round(size * w / h)), size
          dst.parent.mkdir(parents=True, exist_ok=True)
          # NOTE(review): metadata such as EXIF orientation is not passed to save();
          # confirm that dropping it is acceptable for the consumers of the output.
          im.resize((W, H)).save(dst)
  def main(args: Args):
      """Resize every matching image below ``args.source`` into ``args.destination``.

      The source tree is walked recursively. Every file whose lowercased suffix
      is in ``args.extensions`` is resized by ``resize`` and written to the same
      relative path inside the destination. With ``args.n_jobs >= 1`` the work
      is spread over a process pool; otherwise it runs sequentially in this
      process.

      Raises:
          FileNotFoundError: If the source path does not exist.
          ValueError: If the destination exists and ``args.remove_existing``
              is not set, or if removing the destination would also delete
              the source directory.
      """
      source = Path(args.source)
      destination = Path(args.destination).resolve()

      # Explicit raise instead of `assert`, which is stripped under `python -O`.
      if not source.exists():
          raise FileNotFoundError(f"\"{source.resolve()}\" does not exist!")

      if destination.exists():
          if not args.remove_existing:
              raise ValueError(f"\"{destination}\" exists, but --remove_existing was not set!")
          # rmtree would also wipe the input images if the source is the
          # destination itself or lives somewhere below it.
          resolved_source = source.resolve()
          if resolved_source == destination or destination in resolved_source.parents:
              raise ValueError(
                  f"refusing to remove \"{destination}\": "
                  f"it is or contains the source directory \"{resolved_source}\""
              )
          shutil.rmtree(destination)

      logging.info(f"resized images will be written to \"{destination}\"")
      destination.mkdir(parents=True, exist_ok=True)

      # File suffixes are lowercased below, so normalise the configured ones too;
      # otherwise an entry such as ".JPG" could never match.
      extensions = {ext.lower() for ext in args.extensions}
      images = []
      for root, _dirs, files in os.walk(source):
          for file in files:
              if Path(file).suffix.lower() in extensions:
                  images.append(Path(root, file).relative_to(source))
      logging.info(f"Found {len(images)} images in \"{source}\"")

      work = partial(resize,
          source=source,
          dest=destination,
          size=args.size,
          fit_short=args.fit_short,
      )

      if args.n_jobs >= 1:
          with mp.Pool(args.n_jobs) as pool:
              # resize() returns nothing, so result order is irrelevant. The
              # unordered variant lets the progress bar advance as soon as any
              # image finishes. Worker exceptions are still re-raised here.
              for _ in tqdm(pool.imap_unordered(work, images), total=len(images)):
                  pass
      else:
          for imname in tqdm(images):
              work(imname)
  if __name__ == "__main__":
      # Guard the entry point so that re-executing this module (e.g. by
      # multiprocessing workers) does not parse arguments and start a new run.
      # NOTE(review): the import guard at the top of the file raises for any
      # __name__ other than "__main__". That includes "__mp_main__", which
      # 'spawn' workers (default start method on Windows/macOS) use. Confirm
      # that the n_jobs >= 1 path works on those platforms.
      args = BaseParser(Args).parse_args()
      main(args)