rbln.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import logging
  2. import os
  3. from typing import List, Optional, Tuple
  4. from ray._private.accelerators.accelerator import AcceleratorManager
  5. from ray._private.ray_constants import env_bool
  6. logger = logging.getLogger(__name__)
  7. RBLN_RT_VISIBLE_DEVICES_ENV_VAR = "RBLN_DEVICES"
  8. NOSET_RBLN_RT_VISIBLE_DEVICES_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_RBLN_RT_VISIBLE_DEVICES"
  9. class RBLNAcceleratorManager(AcceleratorManager):
  10. """Rebellions RBLN accelerators."""
  11. @staticmethod
  12. def get_resource_name() -> str:
  13. return "RBLN"
  14. @staticmethod
  15. def get_visible_accelerator_ids_env_var() -> str:
  16. return RBLN_RT_VISIBLE_DEVICES_ENV_VAR
  17. @staticmethod
  18. def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
  19. visible_devices = os.environ.get(
  20. RBLNAcceleratorManager.get_visible_accelerator_ids_env_var()
  21. )
  22. if visible_devices is None:
  23. return None
  24. if visible_devices == "":
  25. return []
  26. return visible_devices.split(",")
  27. @staticmethod
  28. def get_current_node_num_accelerators() -> int:
  29. """Detects the number of RBLN devices on the current machine."""
  30. try:
  31. from rebel import device_count
  32. return device_count()
  33. except Exception as e:
  34. logger.debug("Could not detect RBLN devices: %s", e)
  35. return 0
  36. @staticmethod
  37. def get_current_node_accelerator_type() -> Optional[str]:
  38. """Gets the type of RBLN NPU on the current node."""
  39. try:
  40. from rebel import get_npu_name
  41. return get_npu_name()
  42. except Exception as e:
  43. logger.exception("Failed to detect RBLN NPU type: %s", e)
  44. return None
  45. @staticmethod
  46. def validate_resource_request_quantity(
  47. quantity: float,
  48. ) -> Tuple[bool, Optional[str]]:
  49. if isinstance(quantity, float) and not quantity.is_integer():
  50. return (
  51. False,
  52. f"{RBLNAcceleratorManager.get_resource_name()} resource quantity"
  53. " must be whole numbers. "
  54. f"The specified quantity {quantity} is invalid.",
  55. )
  56. else:
  57. return (True, None)
  58. @staticmethod
  59. def set_current_process_visible_accelerator_ids(
  60. visible_rbln_devices: List[str],
  61. ) -> None:
  62. if env_bool(NOSET_RBLN_RT_VISIBLE_DEVICES_ENV_VAR, False):
  63. return
  64. os.environ[
  65. RBLNAcceleratorManager.get_visible_accelerator_ids_env_var()
  66. ] = ",".join(map(str, visible_rbln_devices))