_pallas.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import functools
  2. import torch
  3. @functools.cache
  4. def has_jax_package() -> bool:
  5. """Check if JAX is installed."""
  6. try:
  7. import jax # noqa: F401 # type: ignore[import-not-found]
  8. return True
  9. except ImportError:
  10. return False
  11. @functools.cache
  12. def has_pallas_package() -> bool:
  13. """Check if Pallas (JAX experimental) is available."""
  14. if not has_jax_package():
  15. return False
  16. try:
  17. from jax.experimental import ( # noqa: F401 # type: ignore[import-not-found]
  18. pallas as pl,
  19. )
  20. return True
  21. except ImportError:
  22. return False
  23. @functools.cache
  24. def get_jax_version(fallback: tuple[int, int, int] = (0, 0, 0)) -> tuple[int, int, int]:
  25. """Get JAX version as (major, minor, patch) tuple."""
  26. try:
  27. import jax # type: ignore[import-not-found]
  28. version_parts = jax.__version__.split(".")
  29. major, minor, patch = (int(v) for v in version_parts[:3])
  30. return (major, minor, patch)
  31. except (ImportError, ValueError, AttributeError):
  32. return fallback
  33. @functools.cache
  34. def has_jax_cuda_backend() -> bool:
  35. """Check if JAX has CUDA backend support with SM90+ (required by Mosaic GPU)."""
  36. if not has_jax_package():
  37. return False
  38. try:
  39. import jax # type: ignore[import-not-found]
  40. # Check if CUDA backend is available
  41. devices = jax.devices("gpu")
  42. if len(devices) == 0:
  43. return False
  44. # Mosaic GPU requires SM90+ (compute capability 9.0+)
  45. if torch.cuda.is_available():
  46. major, minor = torch.cuda.get_device_capability()
  47. if major < 9:
  48. return False
  49. return True
  50. except Exception:
  51. return False
  52. @functools.cache
  53. def has_jax_tpu_backend() -> bool:
  54. """Check if JAX has TPU backend support."""
  55. if not has_jax_package():
  56. return False
  57. try:
  58. import jax # type: ignore[import-not-found]
  59. # Check if TPU backend is available
  60. devices = jax.devices("tpu")
  61. return len(devices) > 0
  62. except Exception:
  63. return False
  64. @functools.cache
  65. def has_torch_tpu() -> bool:
  66. """Check if torch_tpu is available."""
  67. try:
  68. import torch_tpu # noqa: F401 # type: ignore[import-not-found]
  69. return True
  70. except ImportError:
  71. return False
  72. @functools.cache
  73. def has_cpu_pallas() -> bool:
  74. """Checks for a full Pallas-on-CPU environment."""
  75. return has_pallas_package()
  76. @functools.cache
  77. def has_cuda_pallas() -> bool:
  78. """Checks for a full Pallas-on-CUDA environment."""
  79. return has_pallas_package() and torch.cuda.is_available() and has_jax_cuda_backend()
  80. @functools.cache
  81. def has_tpu_pallas() -> bool:
  82. """Checks for a full Pallas-on-TPU environment."""
  83. return has_pallas_package() and has_torch_tpu()
  84. @functools.cache
  85. def has_pallas() -> bool:
  86. """
  87. Check if Pallas backend is fully available for use.
  88. Requirements:
  89. - JAX package installed
  90. - Pallas (jax.experimental.pallas) available
  91. - A compatible backend (CUDA or TPU) is available in both PyTorch and JAX.
  92. """
  93. return has_cpu_pallas() or has_cuda_pallas() or has_tpu_pallas()