images.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import importlib
  2. import logging
  3. import numpy as np
  4. from ray.rllib.utils.annotations import DeveloperAPI
  5. logger = logging.getLogger(__name__)
  6. @DeveloperAPI
  7. def is_package_installed(package_name):
  8. try:
  9. importlib.metadata.version(package_name)
  10. return True
  11. except importlib.metadata.PackageNotFoundError:
  12. return False
  13. try:
  14. import cv2
  15. cv2.ocl.setUseOpenCL(False)
  16. logger.debug("CV2 found for image processing.")
  17. except ImportError as e:
  18. if is_package_installed("opencv-python"):
  19. raise ImportError(
  20. f"OpenCV is installed, but we failed to import it. This may be because "
  21. f"you need to install `opencv-python-headless` instead of "
  22. f"`opencv-python`. Error message: {e}",
  23. )
  24. cv2 = None
  25. @DeveloperAPI
  26. def resize(img: np.ndarray, height: int, width: int) -> np.ndarray:
  27. if not cv2:
  28. raise ModuleNotFoundError(
  29. "`opencv` not installed! Do `pip install opencv-python`"
  30. )
  31. return cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
  32. @DeveloperAPI
  33. def rgb2gray(img: np.ndarray) -> np.ndarray:
  34. if not cv2:
  35. raise ModuleNotFoundError(
  36. "`opencv` not installed! Do `pip install opencv-python`"
  37. )
  38. return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  39. @DeveloperAPI
  40. def imread(img_file: str) -> np.ndarray:
  41. if not cv2:
  42. raise ModuleNotFoundError(
  43. "`opencv` not installed! Do `pip install opencv-python`"
  44. )
  45. return cv2.imread(img_file).astype(np.float32)