utils.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Source:
  2. # https://github.com/kubernetes-client/python/blob/master/kubernetes/utils/quantity.py
  3. from decimal import Decimal, InvalidOperation
  4. from functools import reduce
  5. from typing import Optional
  6. # Mapping used to get generation for TPU-{accelerator}-head resource
  7. # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#run
  8. gke_tpu_accelerator_to_generation = {
  9. "tpu-v4-podslice": "v4",
  10. "tpu-v5-lite-device": "v5e",
  11. "tpu-v5-lite-podslice": "v5e",
  12. "tpu-v5p-slice": "v5p",
  13. "tpu-v6e-slice": "v6e",
  14. }
  15. def parse_quantity(quantity):
  16. """Parse kubernetes canonical form quantity like 200Mi to a decimal number.
  17. Supported SI suffixes:
  18. base1024: Ki | Mi | Gi | Ti | Pi | Ei
  19. base1000: n | u | m | "" | k | M | G | T | P | E
  20. See
  21. https://github.com/kubernetes/apimachinery/blob/master/pkg/api/resource/quantity.go
  22. Args:
  23. quantity: string. kubernetes canonical form quantity
  24. Returns:
  25. Decimal: The parsed quantity as a decimal number
  26. Raises:
  27. ValueError: On invalid or unknown input
  28. """
  29. if isinstance(quantity, (int, float, Decimal)):
  30. return Decimal(quantity)
  31. exponents = {
  32. "n": -3,
  33. "u": -2,
  34. "m": -1,
  35. "K": 1,
  36. "k": 1,
  37. "M": 2,
  38. "G": 3,
  39. "T": 4,
  40. "P": 5,
  41. "E": 6,
  42. }
  43. quantity = str(quantity)
  44. number = quantity
  45. suffix = None
  46. if len(quantity) >= 2 and quantity[-1] == "i":
  47. if quantity[-2] in exponents:
  48. number = quantity[:-2]
  49. suffix = quantity[-2:]
  50. elif len(quantity) >= 1 and quantity[-1] in exponents:
  51. number = quantity[:-1]
  52. suffix = quantity[-1:]
  53. try:
  54. number = Decimal(number)
  55. except InvalidOperation:
  56. raise ValueError("Invalid number format: {}".format(number))
  57. if suffix is None:
  58. return number
  59. if suffix.endswith("i"):
  60. base = 1024
  61. elif len(suffix) == 1:
  62. base = 1000
  63. else:
  64. raise ValueError("{} has unknown suffix".format(quantity))
  65. # handle SI inconsistency
  66. if suffix == "ki":
  67. raise ValueError("{} has unknown suffix".format(quantity))
  68. if suffix[0] not in exponents:
  69. raise ValueError("{} has unknown suffix".format(quantity))
  70. exponent = Decimal(exponents[suffix[0]])
  71. return number * (base**exponent)
  72. def tpu_node_selectors_to_type(topology: str, accelerator: str) -> Optional[str]:
  73. """Convert Kubernetes gke-tpu nodeSelectors to TPU accelerator_type
  74. for a kuberay TPU worker group.
  75. Args:
  76. topology: value of the cloud.google.com/gke-tpu-topology Kubernetes
  77. nodeSelector, describes the physical topology of the TPU podslice.
  78. accelerator: value of the cloud.google.com/gke-tpu-accelerator nodeSelector,
  79. the name of the TPU accelerator, e.g. tpu-v4-podslice
  80. Returns:
  81. A string, accelerator_type, e.g. "v4-8".
  82. """
  83. if topology and accelerator:
  84. generation = gke_tpu_accelerator_to_generation[accelerator]
  85. # Reduce e.g. "2x2x2" to 8
  86. chip_dimensions = [int(chip_count) for chip_count in topology.split("x")]
  87. num_chips = reduce(lambda x, y: x * y, chip_dimensions)
  88. default_num_cores_per_chip = 1
  89. if generation == "v4" or generation == "v5p":
  90. default_num_cores_per_chip = 2
  91. num_cores = num_chips * default_num_cores_per_chip
  92. return f"{generation}-{num_cores}"
  93. return None