functions.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import numpy as np
  2. from chainercv import transforms as tr
  3. from collections.abc import Iterable
  4. class GenericPrepare:
  5. def __init__(self, size,
  6. crop_fraction=0.875,
  7. swap_channels=True,
  8. zero_mean=False,
  9. keep_ratio=True):
  10. super().__init__()
  11. self.crop_fraction = crop_fraction
  12. self.swap_channels = swap_channels
  13. self.zero_mean = zero_mean
  14. self.keep_ratio = keep_ratio
  15. def __call__(self, im, size=None, *args, swap_channels=None, keep_ratio=None, zero_mean=None, **kwargs):
  16. size = self.size if size is None else size
  17. swap_channels = self.swap_channels if swap_channels is None else swap_channels
  18. keep_ratio = self.keep_ratio if keep_ratio is None else keep_ratio
  19. zero_mean = self.zero_mean if zero_mean is None else zero_mean
  20. crop_size = None
  21. h, w, c = im.shape
  22. _im = im.transpose(2, 0, 1)
  23. if self.swap_channels:
  24. # RGB -> BGR
  25. _im = _im[::-1]
  26. if self.crop_fraction:
  27. crop_size = (np.array([h, w]) * self.crop_fraction).astype(np.int32)
  28. _im = tr.center_crop(_im, crop_size)
  29. # bilinear interpolation
  30. if self.keep_ratio:
  31. if isinstance(size, tuple):
  32. size = size[0]
  33. _im = tr.scale(_im, size, interpolation=2)
  34. else:
  35. if isinstance(size, int):
  36. size = (size, size)
  37. _im = tr.resize(_im, size, interpolation=2)
  38. if _im.dtype == np.uint8:
  39. # rescale [0 .. 255] -> [0 .. 1]
  40. _im = (_im / 255).astype(np.float32)
  41. if self.zero_mean:
  42. # rescale [0 .. 1] -> [-1 .. 1]
  43. _im = _im * 2 - 1
  44. return _im
  45. class GenericTFPrepare:
  46. def __init__(self, size, crop_fraction, from_path):
  47. super().__init__()
  48. import tensorflow as tf
  49. config_sess = tf.ConfigProto(allow_soft_placement=True)
  50. config_sess.gpu_options.allow_growth = True
  51. self.sess = tf.Session(config=config_sess)
  52. self.from_path = from_path
  53. if from_path:
  54. self.im_input = im_input = tf.placeholder(tf.string)
  55. image = tf.image.decode_jpeg(tf.read_file(im_input), channels=3)
  56. image = tf.image.convert_image_dtype(image, tf.float32)
  57. else:
  58. self.im_input = image = im_input = tf.placeholder(tf.float32, shape=(None, None, 3))
  59. raise NotImplementedError("REFACTOR ME!")
  60. image = tf.image.central_crop(image, central_fraction=crop_fraction)
  61. image = tf.expand_dims(image, 0)
  62. image = tf.image.resize_bilinear(image, [size, size], align_corners=False)
  63. image = tf.squeeze(image, [0])
  64. image = tf.subtract(image, 0.5)
  65. self.output = tf.multiply(image, 2)
  66. def __call__(self, im, *args, **kwargs):
  67. if not self.from_path and im.dtype == np.uint8:
  68. im = im / 255
  69. res = self.sess.run(self.output, feed_dict={self.im_input: im})
  70. return res.transpose(2, 0, 1)
  71. class ChainerCV2Prepare:
  72. def __init__(self, size, *, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
  73. super().__init__()
  74. self.size = size
  75. self.mean = np.array(mean, dtype=np.float32).reshape(-1, 1, 1)
  76. self.std = np.array(std, dtype=np.float32).reshape(-1, 1, 1)
  77. def _size(self, size):
  78. size = self.size if size is None else size
  79. if isinstance(size, Iterable):
  80. size = min(size)
  81. return size
  82. def __call__(self, im, size=None, *args, **kwargs):
  83. _im = im.transpose(2, 0, 1)
  84. _im = _im.astype(np.float32) / 255.0
  85. _im = tr.scale(_im, self._size(size), interpolation=2)
  86. _im -= self.mean
  87. _im /= self.std
  88. return _im