torch.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from typing import TYPE_CHECKING, Callable, Dict, List, Mapping, Optional, Union
  2. import numpy as np
  3. from ray.data._internal.tensor_extensions.utils import _create_possibly_ragged_ndarray
  4. from ray.data.preprocessor import Preprocessor
  5. from ray.data.util.data_batch_conversion import BatchFormat
  6. from ray.util.annotations import PublicAPI
  7. if TYPE_CHECKING:
  8. import torch
  9. @PublicAPI(stability="alpha")
  10. class TorchVisionPreprocessor(Preprocessor):
  11. """Apply a `TorchVision transform <https://pytorch.org/vision/stable/transforms.html>`_
  12. to image columns.
  13. Examples:
  14. Torch models expect inputs of shape :math:`(B, C, H, W)` in the range
  15. :math:`[0.0, 1.0]`. To convert images to this format, add ``ToTensor`` to your
  16. preprocessing pipeline.
  17. .. testcode::
  18. from torchvision import transforms
  19. import ray
  20. from ray.data.preprocessors import TorchVisionPreprocessor
  21. transform = transforms.Compose([
  22. transforms.ToTensor(),
  23. transforms.Resize((224, 224)),
  24. ])
  25. preprocessor = TorchVisionPreprocessor(["image"], transform=transform)
  26. dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
  27. dataset = preprocessor.transform(dataset)
  28. For better performance, set ``batched`` to ``True`` and replace ``ToTensor``
  29. with a batch-supporting ``Lambda``.
  30. .. testcode::
  31. import numpy as np
  32. import torch
  33. def to_tensor(batch: np.ndarray) -> torch.Tensor:
  34. tensor = torch.as_tensor(batch, dtype=torch.float)
  35. # (B, H, W, C) -> (B, C, H, W)
  36. tensor = tensor.permute(0, 3, 1, 2).contiguous()
  37. # [0., 255.] -> [0., 1.]
  38. tensor = tensor.div(255)
  39. return tensor
  40. transform = transforms.Compose([
  41. transforms.Lambda(to_tensor),
  42. transforms.Resize((224, 224))
  43. ])
  44. preprocessor = TorchVisionPreprocessor(["image"], transform=transform, batched=True)
  45. dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
  46. dataset = preprocessor.transform(dataset)
  47. Args:
  48. columns: The columns to apply the TorchVision transform to.
  49. transform: The TorchVision transform you want to apply. This transform should
  50. accept a ``np.ndarray`` or ``torch.Tensor`` as input and return a
  51. ``torch.Tensor`` as output.
  52. output_columns: The output name for each input column. If not specified, this
  53. defaults to the same set of columns as the columns.
  54. batched: If ``True``, apply ``transform`` to batches of shape
  55. :math:`(B, H, W, C)`. Otherwise, apply ``transform`` to individual images.
  56. """ # noqa: E501
  57. _is_fittable = False
  58. def __init__(
  59. self,
  60. columns: List[str],
  61. transform: Callable[[Union["np.ndarray", "torch.Tensor"]], "torch.Tensor"],
  62. output_columns: Optional[List[str]] = None,
  63. batched: bool = False,
  64. ):
  65. super().__init__()
  66. if not output_columns:
  67. output_columns = columns
  68. if len(columns) != len(output_columns):
  69. raise ValueError(
  70. "The length of columns should match the "
  71. f"length of output_columns: {columns} vs {output_columns}."
  72. )
  73. self._columns = columns
  74. self._output_columns = output_columns
  75. self._torchvision_transform = transform
  76. self._batched = batched
  77. def __repr__(self) -> str:
  78. return (
  79. f"{self.__class__.__name__}("
  80. f"columns={self._columns}, "
  81. f"output_columns={self._output_columns}, "
  82. f"transform={self._torchvision_transform!r})"
  83. )
  84. def _transform_numpy(
  85. self, data_batch: Dict[str, "np.ndarray"]
  86. ) -> Dict[str, "np.ndarray"]:
  87. import torch
  88. from ray.data.util.torch_utils import convert_ndarray_to_torch_tensor
  89. def apply_torchvision_transform(array: np.ndarray) -> np.ndarray:
  90. try:
  91. tensor = convert_ndarray_to_torch_tensor(array)
  92. output = self._torchvision_transform(tensor)
  93. except TypeError:
  94. # Transforms like `ToTensor` expect a `np.ndarray` as input.
  95. output = self._torchvision_transform(array)
  96. if isinstance(output, torch.Tensor):
  97. output = output.numpy()
  98. if not isinstance(output, np.ndarray):
  99. raise ValueError(
  100. "`TorchVisionPreprocessor` expected your transform to return a "
  101. "`torch.Tensor` or `np.ndarray`, but your transform returned a "
  102. f"`{type(output).__name__}` instead."
  103. )
  104. return output
  105. def transform_batch(batch: np.ndarray) -> np.ndarray:
  106. if self._batched:
  107. return apply_torchvision_transform(batch)
  108. return _create_possibly_ragged_ndarray(
  109. [apply_torchvision_transform(array) for array in batch]
  110. )
  111. if isinstance(data_batch, Mapping):
  112. for input_col, output_col in zip(self._columns, self._output_columns):
  113. data_batch[output_col] = transform_batch(data_batch[input_col])
  114. else:
  115. # TODO(ekl) deprecate this code path. Unfortunately, predictors are still
  116. # sending schemaless arrays to preprocessors.
  117. data_batch = transform_batch(data_batch)
  118. return data_batch
  119. def get_input_columns(self) -> List[str]:
  120. return self._columns
  121. def get_output_columns(self) -> List[str]:
  122. return self._output_columns
  123. def preferred_batch_format(cls) -> BatchFormat:
  124. return BatchFormat.NUMPY