mxfp.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. """
  2. Helper classes for working with low precision floating point types that
  3. align with the opencompute (OCP) microscaling (MX) specification.
  4. * MXFP4Tensor: 4-bit E2M1 floating point data
  5. * MXScaleTensor: 8-bit E8M0 floating point data
  6. Reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
  7. """
  8. import torch
  9. class MXFP4Tensor:
  10. def __init__(self, data=None, size=None, device=None):
  11. """
  12. Tensor class for working with four bit E2M1 floating point data as defined by the
  13. opencompute microscaling specification.
  14. Parameters:
  15. - data: A torch tensor of float32 numbers to convert to fp4e2m1 microscaling format.
  16. - size: The size of the tensor to create.
  17. - device: The device on which to create the tensor.
  18. """
  19. self.device = device
  20. if data is not None:
  21. assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor"
  22. self.device = data.device
  23. self.data = self._from_float(data)
  24. elif size is not None:
  25. self.size = size if isinstance(size, tuple) else (size, )
  26. else:
  27. raise ValueError("Either parameter data or size must be provided")
  28. def random(self):
  29. S = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device)
  30. E = torch.randint(0, 4, size=self.size, dtype=torch.uint8, device=self.device)
  31. M = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device)
  32. self.data = ((S << 3) | (E << 1) | M).type(torch.uint8)
  33. return self
  34. def to(self, dtype):
  35. """
  36. Convert fp4e2m1 data to float32.
  37. Returns:
  38. - A torch tensor of type dtype representing the fp4e2m1 data.
  39. """
  40. assert dtype == torch.float32, "Currently only float32 is supported for fp4e2m1 to float conversion"
  41. data = self.data
  42. S = ((data >> 3) & 0x1).type(dtype)
  43. E = ((data >> 1) & 0x3).type(dtype)
  44. M = (data & 0x1).type(dtype)
  45. # The MXF4 E2M1 spec defines 0bS000 as zero
  46. value = torch.zeros_like(S)
  47. is_zero = (E == 0) & (M == 0)
  48. non_zero_mask = ~is_zero
  49. if non_zero_mask.any():
  50. S_nz = S[non_zero_mask]
  51. E_nz = E[non_zero_mask]
  52. M_nz = M[non_zero_mask]
  53. sign = torch.pow(-1, S_nz)
  54. # Normal and subnormal handling for the exponent and mantissa
  55. exponent = torch.where(E_nz == 0, E_nz, E_nz - 1)
  56. mantissa = torch.where(E_nz == 0, M_nz * 0.5, 1.0 + M_nz * 0.5)
  57. value_nz = sign * torch.pow(2, exponent) * mantissa
  58. value[non_zero_mask] = value_nz
  59. # For zeros, the values must remain zero with the correct sign
  60. value[is_zero & (S == 1)] *= -1
  61. return value.type(torch.float32)
  62. def _from_float(self, values):
  63. """
  64. Convert float32 numbers to mxf4 e2m1 format.
  65. * No encodings are reserved for Inf or NaN in mxf4.
  66. * Conversion from float supports roundTiesToEven rounding mode.
  67. * If a value exceeds the mxf4 representable range after rounding,
  68. clamps to the maximum mxf4 magnitude, preserving the sign.
  69. * If a value has magnitude less than the minimum subnormal magnitude
  70. in mxf4 after rounding, converts to zero.
  71. Parameters:
  72. - values: A torch tensor of float32 numbers to convert to fp4 format.
  73. """
  74. S = torch.signbit(values).type(torch.uint8)
  75. abs_values = torch.abs(values)
  76. is_zero = (abs_values == 0)
  77. is_invalid = torch.isnan(values) | torch.isinf(values)
  78. # Enumerate all possible E2M1 exponent and mantissa values. We will
  79. # use these to compare the distance between float32 and all possible
  80. # E2M1 floats to find the nearest E2M1 representable value
  81. E_bits = torch.tensor([0, 1, 2, 3], dtype=torch.uint8, device=self.device)
  82. M_bits = torch.tensor([0, 1], dtype=torch.uint8, device=self.device)
  83. candidate_values = []
  84. candidate_E = []
  85. candidate_M = []
  86. for E in E_bits:
  87. if E == 0:
  88. # Subnormals
  89. exponent = 0
  90. for M in M_bits:
  91. significand = M * 0.5
  92. value = significand * (2**exponent)
  93. candidate_values.append(value)
  94. candidate_E.append(E)
  95. candidate_M.append(M)
  96. else:
  97. # Normals
  98. exponent = E.item() - 1
  99. for M in M_bits:
  100. significand = 1.0 + M * 0.5
  101. value = significand * (2**exponent)
  102. candidate_values.append(value)
  103. candidate_E.append(E)
  104. candidate_M.append(M)
  105. candidates = torch.tensor(candidate_values, dtype=torch.float32, device=self.device)
  106. candidate_E = torch.tensor(candidate_E, dtype=torch.uint8, device=self.device)
  107. candidate_M = torch.tensor(candidate_M, dtype=torch.uint8, device=self.device)
  108. abs_values_flat = abs_values.view(-1)
  109. N = abs_values_flat.shape[0]
  110. abs_values_expanded = abs_values_flat.unsqueeze(1)
  111. # Clamp invalid values to the max e2m1 representable value
  112. max_candidate_value = candidates.max().item()
  113. abs_values_flat[is_invalid.view(-1)] = max_candidate_value
  114. # Compute distance between all abs_values and candidate e2m1 values
  115. errors = torch.abs(abs_values_expanded - candidates.unsqueeze(0))
  116. # To implement roundTiesToEven, we need to break ties by preferring
  117. # even mantissas (M == 0). We do so by adding an epsilon bias to shift
  118. # the closest candidate with an even mantissa closer to the float value
  119. min_errors, _ = torch.min(errors, dim=1, keepdim=True)
  120. is_tie = (errors == min_errors)
  121. # More than one candidate has the min error for some float value
  122. if is_tie.sum() > 1:
  123. M_bits_expanded = candidate_M.unsqueeze(0).expand(N, -1)
  124. tie_breaker = (M_bits_expanded == 0).type(torch.int32)
  125. errors = errors - (tie_breaker * 1e-6)
  126. best_indices = torch.argmin(errors, dim=1)
  127. E_selected = candidate_E[best_indices]
  128. M_selected = candidate_M[best_indices]
  129. E = E_selected.view(abs_values.shape)
  130. M = M_selected.view(abs_values.shape)
  131. E[is_zero] = 0
  132. M[is_zero] = 0
  133. return ((S << 3) | (E << 1) | M).type(torch.uint8)
  134. def to_packed_tensor(self, dim):
  135. """
  136. Packs two e2m1 elements into a single uint8 along the specified dimension.
  137. Parameters:
  138. - dim: The dimension along which to pack the elements.
  139. Returns:
  140. - A torch tensor of dtype uint8 with two e2m1 elements packed into one uint8.
  141. """
  142. data = self.data
  143. assert 0 <= dim < data.ndim, \
  144. "The dimension to pack along is not within the range of tensor dimensions"
  145. size_along_dim = data.size(dim)
  146. new_size_along_dim = (size_along_dim + 1) // 2
  147. # If the size is odd, we pad the data along dim with zeros at the end
  148. if size_along_dim % 2 != 0:
  149. pad_sizes = [0] * (2 * data.ndim)
  150. pad_index = (data.ndim - dim - 1) * 2 + 1
  151. pad_sizes[pad_index] = 1
  152. data = torch.nn.functional.pad(data, pad_sizes, mode='constant', value=0)
  153. new_shape = list(data.shape)
  154. new_shape[dim] = new_size_along_dim
  155. new_shape.insert(dim + 1, 2) # packed dimension of length 2
  156. data = data.reshape(*new_shape)
  157. low = data.select(dim + 1, 0)
  158. high = data.select(dim + 1, 1)
  159. packed = (high << 4) | low
  160. return packed
  161. def unpack_packed_tensor(self, packed_tensor, dim, original_shape):
  162. """
  163. Unpacks a tensor where two fp4 elements are packed into a single uint8.
  164. Parameters:
  165. - packed_tensor: The packed tensor
  166. - dim: The dimension along which the tensor was packed.
  167. - original_shape: The shape of the original tensor before packing.
  168. Returns:
  169. - A tensor with the original data unpacked into uint8 elements containing one
  170. fp4e2m1 element in the least significant bits.
  171. """
  172. high = (packed_tensor >> 4) & 0xF
  173. low = packed_tensor & 0xF
  174. stacked = torch.stack((low, high), dim=dim + 1)
  175. # Flatten along dim and dim+1 and then merge
  176. shape = list(stacked.shape)
  177. new_shape = shape[:dim] + [shape[dim] * 2] + shape[dim + 2:]
  178. data = stacked.reshape(*new_shape)
  179. # Remove any padding
  180. if original_shape[dim] % 2 != 0:
  181. indices = [slice(None)] * data.ndim
  182. indices[dim] = slice(0, original_shape[dim])
  183. data = data[tuple(indices)]
  184. return data.type(torch.uint8)
  185. class MXScaleTensor:
  186. def __init__(self, data=None, size=None, device=None):
  187. """
  188. Tensor class for working with microscaling E8M0 block scale factors.
  189. Parameters:
  190. - data: A torch tensor of float32 numbers to convert to fp8e8m0 microscaling format.
  191. - size: The size of the tensor to create.
  192. - device: The device on which to create the tensor.
  193. """
  194. self.device = device
  195. if data is not None:
  196. assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor"
  197. self.device = data.device
  198. self.data = self._from_float(data)
  199. elif size is not None:
  200. self.size = size if isinstance(size, tuple) else (size, )
  201. else:
  202. raise ValueError("Either parameter data or size must be provided")
  203. def random(self, low=None, high=None):
  204. """
  205. Generate random E8M0 data within a specified range.
  206. * Excludes the NaN encoding (255).
  207. """
  208. bias = 127
  209. min_exponent = 0 if low is None else max(0, int(torch.log2(torch.tensor(low))) + bias)
  210. max_exponent = 254 if high is None else min(254, max(0, int(torch.log2(torch.tensor(high))) + bias))
  211. assert min_exponent <= max_exponent, "Low must be less than or equal to high"
  212. E = torch.randint(min_exponent, max_exponent + 1, size=self.size, dtype=torch.uint8, device=self.device)
  213. self.data = E
  214. return self
  215. def to(self, dtype):
  216. assert dtype == torch.float32, "Currently only float32 is supported for f8e8m0 to float conversion"
  217. data = self.data.type(dtype)
  218. is_nan = (data == 255)
  219. e_biased = data.clone()
  220. e_biased[is_nan] = 0
  221. e = e_biased - 127
  222. value = torch.pow(2.0, e)
  223. value[is_nan] = torch.nan
  224. return value.type(dtype)
  225. def _from_float(self, values):
  226. """
  227. Convert float32 numbers to E8M0 format.
  228. * Values <= 0, NaNs, and Infs are converted to the NaN encoding (255).
  229. * Positive values are converted by computing the floor of log2(value) to get the exponent.
  230. Parameters:
  231. - values: A torch tensor of float32 numbers to convert to E8M0 format.
  232. """
  233. result = torch.empty_like(values, dtype=torch.uint8, device=self.device)
  234. is_invalid = torch.isnan(values) | torch.isinf(values) | (values <= 0)
  235. result[is_invalid] = 255
  236. valid_values = values[~is_invalid]
  237. e = torch.floor(torch.log2(valid_values))
  238. e_biased = e + 127
  239. e_biased_int = e_biased.type(torch.int32)
  240. e_biased_clamped = torch.clamp(e_biased_int, 0, 254)
  241. result[~is_invalid] = e_biased_clamped.type(torch.uint8)
  242. return result