_quantized_conversions.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # mypy: allow-untyped-defs
  2. import torch
  3. # Pack pairs of int4 values into int8, in row major order; first int4
  4. # value goes into lower order bits, and second int4 value into higher
  5. # order bits of resulting int8 value.
  6. def pack_int4_to_int8(weight):
  7. if weight.dim() != 2:
  8. raise AssertionError(f"weight must be 2D, got {weight.dim()}D")
  9. if weight.shape[1] % 2 != 0:
  10. raise AssertionError(f"weight.shape[1] must be even, got {weight.shape[1]}")
  11. if weight.dtype != torch.int8:
  12. raise AssertionError(f"weight.dtype must be int8, got {weight.dtype}")
  13. return ((weight[:, 1::2] & 0xF) << 4) | (weight[:, 0::2] & 0xF)
  14. # Unpack quandruples of bits in int8 values into int4 values, in row
  15. # major order; lower 4 bits go into first int4 value goes, and upper 4
  16. # bits go into second int4 value.
  17. def unpack_int8_to_int4(weight):
  18. if weight.dim() != 2:
  19. raise AssertionError(f"weight must be 2D, got {weight.dim()}D")
  20. if weight.dtype != torch.int8:
  21. raise AssertionError(f"weight.dtype must be int8, got {weight.dtype}")
  22. return torch.stack((weight & 0xF, (weight >> 4) & 0xF), dim=2).view(
  23. weight.shape[0], 2 * weight.shape[1]
  24. )
  25. # Transpose the weight matrix, and then reorder its elements according
  26. # to underlying requirements of CUTLASS library, so that it could be
  27. # used for CUTLASS-based mixed datatypes linear operation.
  28. def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
  29. weight, dtypeq, transpose=False
  30. ):
  31. if weight.dim() != 2:
  32. raise AssertionError(f"weight must be 2D, got {weight.dim()}D")
  33. if weight.dtype != torch.int8:
  34. raise AssertionError(f"weight.dtype must be int8, got {weight.dtype}")
  35. if dtypeq != torch.int8 and dtypeq != torch.quint4x2:
  36. raise AssertionError(f"dtypeq must be int8 or quint4x2, got {dtypeq}")
  37. if weight.device.type != "cuda":
  38. raise AssertionError(f"weight must be on CUDA, got {weight.device.type}")
  39. device = weight.device
  40. # subbyte_transpose
  41. if not transpose:
  42. if dtypeq == torch.int8:
  43. outp = weight.T
  44. elif dtypeq == torch.quint4x2:
  45. outp = pack_int4_to_int8(unpack_int8_to_int4(weight.view(torch.int8)).T)
  46. else:
  47. outp = weight
  48. ncols, nrows = outp.shape # type: ignore[possibly-undefined]
  49. divisor = 32 if dtypeq == torch.quint4x2 else 64
  50. if nrows % divisor != 0:
  51. raise AssertionError(f"nrows must be divisible by {divisor}, got {nrows}")
  52. if ncols % 64 != 0:
  53. raise AssertionError(f"ncols must be divisible by 64, got {ncols}")
  54. # permute_B_rows_for_mixed_gemm
  55. # (permute cols actually, as transpose is applied first here)
  56. if dtypeq == torch.quint4x2:
  57. cols_permuted = (
  58. torch.tensor(
  59. [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
  60. device=device,
  61. )
  62. + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand(
  63. nrows // 16, 16
  64. )
  65. ).view(-1)
  66. else:
  67. cols_permuted = (
  68. torch.tensor(
  69. [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15],
  70. device=device,
  71. )
  72. + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand(
  73. nrows // 16, 16
  74. )
  75. ).view(-1)
  76. # pyrefly: ignore [unbound-name]
  77. outp = outp.index_copy(1, cols_permuted, outp)
  78. # interleave_column_major_tensor
  79. magic0 = 4 if dtypeq == torch.quint4x2 else 2
  80. magic1 = 32 // magic0
  81. tmp0 = (
  82. (torch.arange(0, ncols // magic0, device=device) * (nrows // 4 * magic0))
  83. .view(-1, 1)
  84. .repeat(1, nrows // 4 * magic0)
  85. .view(-1)
  86. )
  87. tmp1 = (
  88. (torch.arange(0, nrows // 4 // magic1, device=device) * (magic0 * magic1))
  89. .view(-1, 1)
  90. .repeat(1, magic1)
  91. .view(-1)
  92. .repeat(ncols)
  93. )
  94. tmp2 = (
  95. (torch.arange(0, magic0, device=device) * magic1)
  96. .view(-1, 1)
  97. .repeat(1, nrows // 4)
  98. .view(-1)
  99. .repeat(ncols // magic0)
  100. )
  101. tmp3 = torch.arange(0, magic1, device=device).repeat(nrows // 4 * ncols // magic1)
  102. outp_offsets = tmp0 + tmp1 + tmp2 + tmp3
  103. tmp = outp.view(-1).view(torch.int32)
  104. outp = torch.zeros_like(tmp)
  105. outp.scatter_(0, outp_offsets, tmp)
  106. outp = outp.view(weight.dtype)
  107. # add_bias_and_interleave_quantized_tensor_inplace
  108. tmp = outp.view(-1)
  109. outp = torch.empty_like(tmp)
  110. if dtypeq == torch.int8:
  111. tmp = (tmp.to(torch.int) + 128).to(tmp.dtype)
  112. outp[0::4] = tmp[0::4]
  113. outp[1::4] = tmp[2::4]
  114. outp[2::4] = tmp[1::4]
  115. outp[3::4] = tmp[3::4]
  116. elif dtypeq == torch.quint4x2:
  117. tmp0 = ((tmp & 0xF) + 8) & 0xF
  118. tmp0 = (tmp0[1::2] << 4) | tmp0[0::2]
  119. tmp1 = (((tmp >> 4) & 0xF) + 8) & 0xF
  120. tmp1 = (tmp1[1::2] << 4) | tmp1[0::2]
  121. outp[0::4] = tmp0[0::2]
  122. outp[1::4] = tmp0[1::2]
  123. outp[2::4] = tmp1[0::2]
  124. outp[3::4] = tmp1[1::2]
  125. if dtypeq == torch.quint4x2:
  126. nrows *= 2
  127. ncols //= 2
  128. return outp.view(nrows, ncols).view(torch.uint8)