onnxruntime_validation.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. """
  6. Check OS requirements for ONNX Runtime Python Bindings.
  7. """
  8. import linecache
  9. import platform
  10. import warnings
  11. def check_distro_info():
  12. __my_distro__ = ""
  13. __my_distro_ver__ = ""
  14. __my_system__ = platform.system().lower()
  15. __OS_RELEASE_FILE__ = "/etc/os-release" # noqa: N806
  16. __LSB_RELEASE_FILE__ = "/etc/lsb-release" # noqa: N806
  17. if __my_system__ == "windows":
  18. __my_distro__ = __my_system__
  19. __my_distro_ver__ = platform.release().lower()
  20. if __my_distro_ver__ not in ["10", "11", "2016server", "2019server", "2022server", "2025server"]:
  21. warnings.warn(
  22. f"Unsupported Windows version ({__my_distro_ver__}). ONNX Runtime supports Windows 10 and above, or Windows Server 2016 and above."
  23. )
  24. elif __my_system__ == "linux":
  25. """Although the 'platform' python module for getting Distro information works well on standard OS images
  26. running on real hardware, it is not accurate when running on Azure VMs, Git Bash, Cygwin, etc.
  27. The returned values for release and version are unpredictable for virtualized or emulated environments.
  28. /etc/os-release and /etc/lsb_release files, on the other hand, are guaranteed to exist and have standard values
  29. in all OSes supported by onnxruntime. The former is the current standard file to check OS info and the latter
  30. is its predecessor.
  31. """
  32. # Newer systems have /etc/os-release with relevant distro info
  33. __my_distro__ = linecache.getline(__OS_RELEASE_FILE__, 3)[3:-1]
  34. __my_distro_ver__ = linecache.getline(__OS_RELEASE_FILE__, 6)[12:-2]
  35. # Older systems may have /etc/os-release instead
  36. if not __my_distro__:
  37. __my_distro__ = linecache.getline(__LSB_RELEASE_FILE__, 1)[11:-1]
  38. __my_distro_ver__ = linecache.getline(__LSB_RELEASE_FILE__, 2)[16:-1]
  39. # Instead of trying to parse distro specific files,
  40. # warn the user ONNX Runtime may not work out of the box
  41. __my_distro__ = __my_distro__.lower()
  42. __my_distro_ver__ = __my_distro_ver__.lower()
  43. elif __my_system__ == "darwin":
  44. __my_distro__ = __my_system__
  45. __my_distro_ver__ = platform.release().lower()
  46. if int(__my_distro_ver__.split(".")[0]) < 11:
  47. warnings.warn(
  48. f"Unsupported macOS version ({__my_distro_ver__}). ONNX Runtime supports macOS 11.0 or later."
  49. )
  50. elif __my_system__ == "aix":
  51. import subprocess # noqa: PLC0415
  52. returned_output = subprocess.check_output("oslevel")
  53. __my_distro_ver__str = returned_output.decode("utf-8")
  54. __my_distro_ver = __my_distro_ver__str[:3]
  55. else:
  56. warnings.warn(
  57. f"Unsupported platform ({__my_system__}). ONNX Runtime supports Linux, macOS, AIX and Windows platforms, only."
  58. )
  59. def get_package_name_and_version_info():
  60. package_name = ""
  61. version = ""
  62. cuda_version = ""
  63. try:
  64. from .build_and_package_info import __version__ as version # noqa: PLC0415
  65. from .build_and_package_info import package_name # noqa: PLC0415
  66. try: # noqa: SIM105
  67. from .build_and_package_info import cuda_version # noqa: PLC0415
  68. except ImportError:
  69. # cuda_version is optional. For example, cpu only package does not have the attribute.
  70. pass
  71. except Exception as e:
  72. warnings.warn("WARNING: failed to collect package name and version info")
  73. print(e)
  74. return package_name, version, cuda_version
  75. def check_training_module():
  76. import_ortmodule_exception = None
  77. has_ortmodule = False
  78. try:
  79. from onnxruntime.training.ortmodule import ORTModule # noqa: F401, PLC0415
  80. has_ortmodule = True
  81. except ImportError:
  82. # ORTModule not present
  83. has_ortmodule = False
  84. except Exception as e:
  85. # this may happen if Cuda is not installed, we want to raise it after
  86. # for any exception other than not having ortmodule, we want to continue
  87. # device version validation and raise the exception after.
  88. try:
  89. from onnxruntime.training.ortmodule._fallback import ORTModuleInitException # noqa: PLC0415
  90. if isinstance(e, ORTModuleInitException):
  91. # ORTModule is present but not ready to run yet
  92. has_ortmodule = True
  93. except Exception:
  94. # ORTModule not present
  95. has_ortmodule = False
  96. if not has_ortmodule:
  97. import_ortmodule_exception = e
  98. # collect onnxruntime package name, version, and cuda version
  99. package_name, version, cuda_version = get_package_name_and_version_info()
  100. if has_ortmodule and cuda_version:
  101. try:
  102. # collect cuda library build info. the library info may not be available
  103. # when the build environment has none or multiple libraries installed
  104. try:
  105. from .build_and_package_info import cudart_version # noqa: PLC0415
  106. except ImportError:
  107. warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
  108. cudart_version = None
  109. def print_build_package_info():
  110. warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
  111. warnings.warn(f"onnxruntime training package info: __version__: {version}")
  112. warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
  113. warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")
  114. # collection cuda library info from current environment.
  115. from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions # noqa: PLC0415
  116. local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
  117. if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
  118. print_build_package_info()
  119. warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
  120. warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}")
  121. except Exception as e:
  122. warnings.warn("WARNING: failed to collect onnxruntime version and build info")
  123. print(e)
  124. if import_ortmodule_exception:
  125. raise import_ortmodule_exception
  126. return has_ortmodule, package_name, version, cuda_version