# extension.py (2.7 KB)
  1. import os
  2. import torch
  3. from ._internally_replaced_utils import _get_extension_path
  4. def _load_library(lib_name):
  5. """Load a library, optionally warning on failure based on env variable.
  6. Returns True if the library was loaded successfully, False otherwise.
  7. """
  8. try:
  9. lib_path = _get_extension_path(lib_name)
  10. torch.ops.load_library(lib_path)
  11. return True
  12. except (ImportError, OSError) as e:
  13. if os.environ.get("TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS"):
  14. import warnings
  15. warnings.warn(f"Failed to load '{lib_name}' extension: {type(e).__name__}: {e}")
  16. return False
  17. def _has_ops():
  18. return False
  19. if _load_library("_C"):
  20. def _has_ops(): # noqa: F811
  21. return True
  22. def _assert_has_ops():
  23. if not _has_ops():
  24. raise RuntimeError(
  25. "Couldn't load custom C++ ops. This can happen if your PyTorch and "
  26. "torchvision versions are incompatible, or if you had errors while compiling "
  27. "torchvision from source. For further information on the compatible versions, check "
  28. "https://github.com/pytorch/vision#installation for the compatibility matrix. "
  29. "Please check your PyTorch version with torch.__version__ and your torchvision "
  30. "version with torchvision.__version__ and verify if they are compatible, and if not "
  31. "please reinstall torchvision so that it matches your PyTorch install. "
  32. "Set TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS=1 and retry to get more details."
  33. )
  34. def _check_cuda_version():
  35. """
  36. Make sure that CUDA versions match between the pytorch install and torchvision install
  37. """
  38. if not _has_ops():
  39. return -1
  40. from torch.version import cuda as torch_version_cuda
  41. _version = torch.ops.torchvision._cuda_version()
  42. if _version != -1 and torch_version_cuda is not None:
  43. tv_version = str(_version)
  44. assert int(tv_version) >= 12000, f"Unexpected CUDA version {_version}, please file a bug report."
  45. tv_major = int(tv_version[0:2])
  46. tv_minor = int(tv_version[3])
  47. t_version = torch_version_cuda.split(".")
  48. t_major = int(t_version[0])
  49. t_minor = int(t_version[1])
  50. if t_major != tv_major:
  51. raise RuntimeError(
  52. "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
  53. f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
  54. f"CUDA Version={tv_major}.{tv_minor}. "
  55. "Please reinstall the torchvision that matches your PyTorch install."
  56. )
  57. return _version
# Run the CUDA compatibility check once, at import time, so a PyTorch/torchvision
# CUDA major-version mismatch surfaces immediately as a RuntimeError rather than
# later, inside an op call.
_check_cuda_version()