tf_preprocessing.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. """ Tensorflow Preprocessing Adapter
  2. Allows use of Tensorflow preprocessing pipeline in PyTorch Transform
  3. Copyright of original Tensorflow code below.
  4. Hacked together by / Copyright 2020 Ross Wightman
  5. """
  6. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. # ==============================================================================
  20. """ImageNet preprocessing for MnasNet."""
  21. import tensorflow.compat.v1 as tf
  22. import numpy as np
  23. IMAGE_SIZE = 224
  24. CROP_PADDING = 32
  25. tf.compat.v1.disable_eager_execution()
  26. def distorted_bounding_box_crop(image_bytes,
  27. bbox,
  28. min_object_covered=0.1,
  29. aspect_ratio_range=(0.75, 1.33),
  30. area_range=(0.05, 1.0),
  31. max_attempts=100,
  32. scope=None):
  33. """Generates cropped_image using one of the bboxes randomly distorted.
  34. See `tf.image.sample_distorted_bounding_box` for more documentation.
  35. Args:
  36. image_bytes: `Tensor` of binary image data.
  37. bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
  38. where each coordinate is [0, 1) and the coordinates are arranged
  39. as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
  40. image.
  41. min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
  42. area of the image must contain at least this fraction of any bounding
  43. box supplied.
  44. aspect_ratio_range: An optional list of `float`s. The cropped area of the
  45. image must have an aspect ratio = width / height within this range.
  46. area_range: An optional list of `float`s. The cropped area of the image
  47. must contain a fraction of the supplied image within in this range.
  48. max_attempts: An optional `int`. Number of attempts at generating a cropped
  49. region of the image of the specified constraints. After `max_attempts`
  50. failures, return the entire image.
  51. scope: Optional `str` for name scope.
  52. Returns:
  53. cropped image `Tensor`
  54. """
  55. with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]):
  56. shape = tf.image.extract_jpeg_shape(image_bytes)
  57. sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
  58. shape,
  59. bounding_boxes=bbox,
  60. min_object_covered=min_object_covered,
  61. aspect_ratio_range=aspect_ratio_range,
  62. area_range=area_range,
  63. max_attempts=max_attempts,
  64. use_image_if_no_bounding_boxes=True)
  65. bbox_begin, bbox_size, _ = sample_distorted_bounding_box
  66. # Crop the image to the specified bounding box.
  67. offset_y, offset_x, _ = tf.unstack(bbox_begin)
  68. target_height, target_width, _ = tf.unstack(bbox_size)
  69. crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
  70. image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  71. return image
  72. def _at_least_x_are_equal(a, b, x):
  73. """At least `x` of `a` and `b` `Tensors` are equal."""
  74. match = tf.equal(a, b)
  75. match = tf.cast(match, tf.int32)
  76. return tf.greater_equal(tf.reduce_sum(match), x)
  77. def _decode_and_random_crop(image_bytes, image_size, resize_method):
  78. """Make a random crop of image_size."""
  79. bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  80. image = distorted_bounding_box_crop(
  81. image_bytes,
  82. bbox,
  83. min_object_covered=0.1,
  84. aspect_ratio_range=(3. / 4, 4. / 3.),
  85. area_range=(0.08, 1.0),
  86. max_attempts=10,
  87. scope=None)
  88. original_shape = tf.image.extract_jpeg_shape(image_bytes)
  89. bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
  90. image = tf.cond(
  91. bad,
  92. lambda: _decode_and_center_crop(image_bytes, image_size),
  93. lambda: tf.image.resize([image], [image_size, image_size], resize_method)[0])
  94. return image
  95. def _decode_and_center_crop(image_bytes, image_size, resize_method):
  96. """Crops to center of image with padding then scales image_size."""
  97. shape = tf.image.extract_jpeg_shape(image_bytes)
  98. image_height = shape[0]
  99. image_width = shape[1]
  100. padded_center_crop_size = tf.cast(
  101. ((image_size / (image_size + CROP_PADDING)) *
  102. tf.cast(tf.minimum(image_height, image_width), tf.float32)),
  103. tf.int32)
  104. offset_height = ((image_height - padded_center_crop_size) + 1) // 2
  105. offset_width = ((image_width - padded_center_crop_size) + 1) // 2
  106. crop_window = tf.stack([offset_height, offset_width,
  107. padded_center_crop_size, padded_center_crop_size])
  108. image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  109. image = tf.image.resize([image], [image_size, image_size], resize_method)[0]
  110. return image
  111. def _flip(image):
  112. """Random horizontal image flip."""
  113. image = tf.image.random_flip_left_right(image)
  114. return image
  115. def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
  116. """Preprocesses the given image for evaluation.
  117. Args:
  118. image_bytes: `Tensor` representing an image binary of arbitrary size.
  119. use_bfloat16: `bool` for whether to use bfloat16.
  120. image_size: image size.
  121. interpolation: image interpolation method
  122. Returns:
  123. A preprocessed image `Tensor`.
  124. """
  125. resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
  126. image = _decode_and_random_crop(image_bytes, image_size, resize_method)
  127. image = _flip(image)
  128. image = tf.reshape(image, [image_size, image_size, 3])
  129. image = tf.image.convert_image_dtype(
  130. image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
  131. return image
  132. def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
  133. """Preprocesses the given image for evaluation.
  134. Args:
  135. image_bytes: `Tensor` representing an image binary of arbitrary size.
  136. use_bfloat16: `bool` for whether to use bfloat16.
  137. image_size: image size.
  138. interpolation: image interpolation method
  139. Returns:
  140. A preprocessed image `Tensor`.
  141. """
  142. resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
  143. image = _decode_and_center_crop(image_bytes, image_size, resize_method)
  144. image = tf.reshape(image, [image_size, image_size, 3])
  145. image = tf.image.convert_image_dtype(
  146. image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
  147. return image
  148. def preprocess_image(image_bytes,
  149. is_training=False,
  150. use_bfloat16=False,
  151. image_size=IMAGE_SIZE,
  152. interpolation='bicubic'):
  153. """Preprocesses the given image.
  154. Args:
  155. image_bytes: `Tensor` representing an image binary of arbitrary size.
  156. is_training: `bool` for whether the preprocessing is for training.
  157. use_bfloat16: `bool` for whether to use bfloat16.
  158. image_size: image size.
  159. interpolation: image interpolation method
  160. Returns:
  161. A preprocessed image `Tensor` with value range of [0, 255].
  162. """
  163. if is_training:
  164. return preprocess_for_train(image_bytes, use_bfloat16, image_size, interpolation)
  165. else:
  166. return preprocess_for_eval(image_bytes, use_bfloat16, image_size, interpolation)
  167. class TfPreprocessTransform:
  168. def __init__(self, is_training=False, size=224, interpolation='bicubic'):
  169. self.is_training = is_training
  170. self.size = size[0] if isinstance(size, tuple) else size
  171. self.interpolation = interpolation
  172. self._image_bytes = None
  173. self.process_image = self._build_tf_graph()
  174. self.sess = None
  175. def _build_tf_graph(self):
  176. with tf.device('/cpu:0'):
  177. self._image_bytes = tf.placeholder(
  178. shape=[],
  179. dtype=tf.string,
  180. )
  181. img = preprocess_image(
  182. self._image_bytes, self.is_training, False, self.size, self.interpolation)
  183. return img
  184. def __call__(self, image_bytes):
  185. if self.sess is None:
  186. self.sess = tf.Session()
  187. img = self.sess.run(self.process_image, feed_dict={self._image_bytes: image_bytes})
  188. img = img.round().clip(0, 255).astype(np.uint8)
  189. if img.ndim < 3:
  190. img = np.expand_dims(img, axis=-1)
  191. img = np.rollaxis(img, 2) # HWC to CHW
  192. return img