_device_limits.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import torch
  2. from torch._C import dtype
  3. __all__ = ["GPULimits"]
  4. class GPULimits:
  5. r"""Utility class that provides the theoretical limits of Nvidia GPU devices. The
  6. limits don't take into account thermal throttling (assume that the GPU run at its
  7. peak rated frequency). This is because user hardware configuration may influence
  8. power behavior.
  9. """
  10. def __init__(self, target_device: torch.device):
  11. # The device properties object is obtained by calling 'cudaGetDeviceProperties' CUDA
  12. # runtime function. We need the total memory bus width and the memory clock rate to
  13. # calculate the memory bandwidth.
  14. self.device_properties = torch.cuda.get_device_properties(target_device)
  15. # The compute capability is needed to determine the number of FLOPs per cycle per SM
  16. self.compute_capability = int(
  17. f"{self.device_properties.major}{self.device_properties.minor}"
  18. )
  19. # FLOPs per cycle information derived from Table 2 in:
  20. # https://resources.nvidia.com/en-us-hopper-architecture/nvidia-h100-tensor-c
  21. # Returns the number of FMA instructions retired per cycle per SM for a given
  22. # data type, when tensor cores are NOT used
  23. def get_fma_per_cycle_per_sm_cuda_cores(self, data_type: dtype) -> int:
  24. hardcoded_device_values = {
  25. # Ampere Architecture
  26. "fp16_80": 256,
  27. "fp32_80": 64,
  28. "fp64_80": 32,
  29. # Hopper Architecture
  30. "fp16_90": 64,
  31. "fp32_90": 128,
  32. "fp64_90": 64,
  33. # Blackwell Architecture
  34. "fp16_100": 256,
  35. "fp32_100": 128,
  36. "fp64_100": 64,
  37. }
  38. dict_key = ""
  39. if data_type is torch.float16:
  40. dict_key = f"fp16_{self.compute_capability}"
  41. elif data_type is torch.float32:
  42. dict_key = f"fp32_{self.compute_capability}"
  43. elif data_type is torch.float64:
  44. dict_key = f"fp64_{self.compute_capability}"
  45. else:
  46. dict_key = "unknown"
  47. if dict_key not in hardcoded_device_values:
  48. raise RuntimeError(
  49. f"No data for sm_{self.compute_capability} and {data_type}."
  50. )
  51. return hardcoded_device_values[dict_key]
  52. # Returns the number of FMA instructions retired per cycle per SM for a given
  53. # data type, when tensor cores ARE used
  54. def get_fma_per_cycle_per_sm_tensor_cores(self, data_type: dtype) -> int:
  55. hardcoded_device_values = {
  56. # Ampere Architecture
  57. "int8_80": 2048,
  58. "fp16_80": 1024,
  59. "fp32_80": 512,
  60. "fp64_80": 64,
  61. # Hopper Architecture
  62. "int8_90": 4096,
  63. "fp8_90": 4096,
  64. "fp16_90": 2048,
  65. "fp32_90": 1024,
  66. "fp64_90": 128,
  67. # Blackwell Architecture
  68. "int8_100": 8192,
  69. "fp8_100": 8192,
  70. "fp16_100": 4096,
  71. "fp32_100": 2048,
  72. }
  73. dict_key = ""
  74. if data_type is torch.float16:
  75. dict_key = f"fp16_{self.compute_capability}"
  76. elif data_type is torch.bfloat16:
  77. # FP16 and BF16 are equivalent in terms of FLOPs per cycle per SM
  78. dict_key = f"fp16_{self.compute_capability}"
  79. elif data_type is torch.float32:
  80. dict_key = f"fp32_{self.compute_capability}"
  81. elif data_type is torch.int8:
  82. dict_key = f"int8_{self.compute_capability}"
  83. elif data_type is torch.float64:
  84. dict_key = f"fp64_{self.compute_capability}"
  85. else:
  86. dict_key = "unknown"
  87. if dict_key not in hardcoded_device_values:
  88. raise RuntimeError(
  89. f"No data for sm_{self.compute_capability} and {data_type}."
  90. )
  91. return hardcoded_device_values[dict_key]
  92. def get_tflops_per_second(
  93. self, data_type: dtype, use_tensor_cores: bool = True
  94. ) -> float:
  95. num_sms = self.device_properties.multi_processor_count
  96. clock_rate = self.device_properties.clock_rate # KHz
  97. fma_per_cycle = 0
  98. if use_tensor_cores:
  99. fma_per_cycle = self.get_fma_per_cycle_per_sm_tensor_cores(data_type)
  100. else:
  101. fma_per_cycle = self.get_fma_per_cycle_per_sm_cuda_cores(data_type)
  102. # 1 FMA counts as 2 floating point operations
  103. # Clock rate is in KHz
  104. tflops_per_second = num_sms * fma_per_cycle * 2 * clock_rate / 1e9
  105. return tflops_per_second
  106. def get_memory_bandwidth_Bps(self) -> int:
  107. # DRAM devices are Double-Data which means they provide an output at both fronts of
  108. # a clock beat
  109. bus_bytes_per_cycle = int(2 * self.device_properties.memory_bus_width / 8)
  110. mem_clock_rate_Hz = self.device_properties.memory_clock_rate * 1000
  111. bytes_per_second = bus_bytes_per_cycle * mem_clock_rate_Hz
  112. return bytes_per_second
  113. def get_shared_memory_bandwidth_Bps(self) -> int:
  114. # Each warp can LD or ST 32 x 4 bytes per cycle. To calculate the
  115. # device's throughput we need to multiply with frequency and number of SMs.
  116. num_sms = self.device_properties.multi_processor_count
  117. bytes_per_cycle_per_sm = 128
  118. bytes_per_cycle_per_device = num_sms * bytes_per_cycle_per_sm
  119. bytes_per_second = (
  120. bytes_per_cycle_per_device * self.device_properties.clock_rate * 1000
  121. )
  122. return bytes_per_second