higgs.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. "HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file"
  15. from math import sqrt
  16. from ..quantizers.quantizers_utils import should_convert_module
  17. from ..utils import is_flute_available, is_hadamard_available, is_torch_available, logging
  18. if is_torch_available():
  19. import torch
  20. import torch.nn as nn
  21. if is_flute_available():
  22. from flute.integrations.higgs import prepare_data_transposed
  23. from flute.tune import TuneMetaData, qgemm_v2
  24. if is_hadamard_available():
  25. from fast_hadamard_transform import hadamard_transform
  26. logger = logging.get_logger(__name__)
  27. def pad_to_block(tensor, dims, had_block_size, value=0):
  28. pad_dims = [0 for _ in range(2 * len(tensor.shape))]
  29. for dim in dims:
  30. size = tensor.shape[dim]
  31. next_multiple_of_1024 = ((size - 1) // had_block_size + 1) * had_block_size
  32. delta = next_multiple_of_1024 - size
  33. pad_dims[-2 * dim - 1] = delta
  34. return nn.functional.pad(tensor, pad_dims, "constant", value)
def get_higgs_grid(p: int, n: int) -> "torch.Tensor":
    """
    Return the hard-coded HIGGS quantization grid with `n` points of dimension `p`.

    Supported `(p, n)` combinations are `(2, 256)`, `(2, 64)`, `(2, 16)`, `(1, 16)`, `(1, 8)` and `(1, 4)`.

    Args:
        p (`int`):
            Dimension of each grid point (number of columns of the returned tensor).
        n (`int`):
            Number of grid points (number of rows of the returned tensor).

    Returns:
        `torch.Tensor`: A freshly built tensor of shape `(n, p)` holding the grid points.

    Raises:
        NotImplementedError: If the `(p, n)` combination has no table in this function.
    """
    # NOTE(review): the values below are literal constants; how they were derived is not visible in this file.
    # 2-dimensional grid, 256 points.
    if (p, n) == (2, 256):
        return torch.tensor(
            [
                [-2.501467704772949, 0.17954708635807037],
                [-0.6761789321899414, 1.2728623151779175],
                [-1.8025816679000854, 0.7613157629966736],
                [-0.538287878036499, -2.6028504371643066],
                [0.8415029644966125, -0.8600977659225464],
                [0.7023013234138489, 3.3138747215270996],
                [0.5699077844619751, 2.5782253742218018],
                [3.292393207550049, -0.6016128063201904],
                [0.5561617016792297, -1.7723814249038696],
                [-2.1012380123138428, 0.020958125591278076],
                [0.46085724234580994, 0.8428705334663391],
                [1.4548040628433228, -0.6156039237976074],
                [3.210029363632202, 0.3546904921531677],
                [0.8893890976905823, -0.5967988967895508],
                [0.8618854284286499, -3.2061192989349365],
                [1.1360996961593628, -0.23852407932281494],
                [1.6646337509155273, -0.9265465140342712],
                [1.4767773151397705, 1.2476022243499756],
                [-1.0511897802352905, 1.94503915309906],
                [-1.56318998336792, -0.3264186680316925],
                [-0.1829211413860321, 0.2922491431236267],
                [-0.8950616717338562, -1.3887052536010742],
                [-0.08206957578659058, -1.329533576965332],
                [-0.487422913312912, 1.4817842245101929],
                [-1.6769757270812988, -2.8269758224487305],
                [-1.5057679414749146, 1.8905963897705078],
                [1.8335362672805786, 1.0515104532241821],
                [0.3273945450782776, 1.0491033792495728],
                [-3.295924186706543, -0.7021600008010864],
                [-1.8428784608840942, -1.2315762042999268],
                [-0.8575026392936707, -1.7005949020385742],
                [-1.120667815208435, 0.6467998027801514],
                [-0.1588846743106842, -1.804071068763733],
                [-0.8539647459983826, 0.5645008683204651],
                [-1.4192019701004028, -0.6175029873847961],
                [1.0799058675765991, 1.7871345281600952],
                [1.171311855316162, 0.7511613965034485],
                [2.162078380584717, 0.8044339418411255],
                [1.3969420194625854, -1.243762493133545],
                [-0.23818807303905487, 0.053944624960422516],
                [2.304199457168579, -1.2667627334594727],
                [1.4225027561187744, 0.568610668182373],
                [0.376836895942688, -0.7134661674499512],
                [2.0404467582702637, 0.4087389409542084],
                [0.7639489769935608, -1.1367933750152588],
                [0.3622530400753021, -1.4827953577041626],
                [0.4100743532180786, 0.36108437180519104],
                [-1.5867475271224976, -1.618212342262268],
                [-2.2769672870635986, -1.2132309675216675],
                [0.9184022545814514, -0.34428009390830994],
                [-0.3902314603328705, 0.21785245835781097],
                [3.120687484741211, 1.3077973127365112],
                [1.587440848350525, -1.6506884098052979],
                [-1.718808889389038, -0.038405973464250565],
                [-0.6888407468795776, -0.8402308821678162],
                [-0.7981445789337158, -1.1117373704910278],
                [-2.4124443531036377, 1.3419722318649292],
                [-0.6611530184745789, 0.9939885139465332],
                [-0.33103418350219727, -0.16702833771705627],
                [-2.4091389179229736, -2.326857566833496],
                [1.6610108613967896, -2.159703254699707],
                [0.014884627424180508, 0.3887578248977661],
                [0.029668325558304787, 1.8786455392837524],
                [1.180362582206726, 2.699317216873169],
                [1.821286678314209, -0.5960053205490112],
                [-0.44835323095321655, 3.327436685562134],
                [-0.3714401423931122, -2.1466753482818604],
                [-1.1103475093841553, -2.4536871910095215],
                [-0.39110705256462097, 0.6670510172843933],
                [0.474752813577652, -1.1959707736968994],
                [-0.013110585510730743, -2.52519154548645],
                [-2.0836575031280518, -1.703289270401001],
                [-1.1077687740325928, -0.1252644956111908],
                [-0.4138077199459076, 1.1837692260742188],
                [-1.977599024772644, 1.688241720199585],
                [-1.659559965133667, -2.1387736797332764],
                [0.03242531046271324, 0.6526556015014648],
                [0.9127950072288513, 0.6099498867988586],
                [-0.38478314876556396, 0.433487206697464],
                [0.27454206347465515, -0.27719801664352417],
                [0.10388526320457458, 2.2812814712524414],
                [-0.014394169673323631, -3.177137613296509],
                [-1.2871228456497192, -0.8961855173110962],
                [0.5720916986465454, -0.921597957611084],
                [1.1159656047821045, -0.7609877586364746],
                [2.4383342266082764, -2.2983546257019043],
                [-0.294057160615921, -0.9770799875259399],
                [-0.9342701435089111, 1.107579231262207],
                [-1.549338698387146, 3.090520143508911],
                [2.6076579093933105, 2.051239013671875],
                [-0.9259037375450134, 1.407211184501648],
                [-0.1747353971004486, 0.540488600730896],
                [-0.8963701725006104, 0.8271111249923706],
                [0.6480194926261902, 1.0128909349441528],
                [0.980783998966217, -0.06156221032142639],
                [-0.16883476078510284, 1.0601658821105957],
                [0.5839992761611938, 0.004697148688137531],
                [-0.34228450059890747, -1.2423977851867676],
                [2.500824451446533, 0.3665279746055603],
                [-0.17641609907150269, 1.3529551029205322],
                [0.05378641560673714, 2.817232847213745],
                [-1.2391047477722168, 2.354328155517578],
                [0.630434513092041, -0.668536365032196],
                [1.7576488256454468, 0.6738647818565369],
                [0.4435231387615204, 0.6000469326972961],
                [-0.08794835954904556, -0.11511358618736267],
                [1.6540337800979614, 0.33995017409324646],
                [-0.04202975332736969, -0.5375117063522339],
                [-0.4247745871543884, -0.7897617220878601],
                [0.06695003807544708, 1.2000739574432373],
                [-3.2508881092071533, 0.28734830021858215],
                [-1.613816261291504, 0.4944162368774414],
                [1.3598989248275757, 0.26117825508117676],
                [2.308382511138916, 1.3462618589401245],
                [-1.2137469053268433, -1.9254342317581177],
                [-0.4889402985572815, 1.8136259317398071],
                [-0.1870335340499878, -0.3480615019798279],
                [1.0766386985778809, -1.0627082586288452],
                [0.4651014506816864, 2.131748914718628],
                [-0.1306295394897461, -0.7811847925186157],
                [0.06433182954788208, -1.5397958755493164],
                [-0.2894323468208313, -0.5789554715156555],
                [-0.6081662178039551, 0.4845278263092041],
                [2.697964668273926, -0.18515698611736298],
                [0.1277363896369934, -0.7221432328224182],
                [0.8700758218765259, 0.35042452812194824],
                [0.22088994085788727, 0.495242178440094],
                [-2.5843818187713623, -0.8000828623771667],
                [0.6732649803161621, -1.4362232685089111],
                [-1.5286413431167603, 1.0417330265045166],
                [-1.1222513914108276, -0.6269875764846802],
                [-0.9752035140991211, -0.8750635385513306],
                [-2.6369473934173584, 0.6918523907661438],
                [0.14478731155395508, -0.041986867785453796],
                [-1.5629483461380005, 1.4369450807571411],
                [0.38952457904815674, -2.16428804397583],
                [-0.16885095834732056, 0.7976621985435486],
                [-3.12416934967041, 1.256506085395813],
                [0.6843105554580688, -0.4203019142150879],
                [1.9345275163650513, 1.934950351715088],
                [0.012184220366179943, -2.1080918312072754],
                [-0.6350273489952087, 0.7358828186988831],
                [-0.837304949760437, -0.6214472651481628],
                [0.08211923390626907, -0.9472538232803345],
                [2.9332995414733887, -1.4956780672073364],
                [1.3806978464126587, -0.2916182279586792],
                [0.06773144006729126, 0.9285762310028076],
                [-1.1943119764328003, 1.5963770151138306],
                [1.6395620107650757, -0.32285431027412415],
                [-1.390851378440857, -0.08273141086101532],
                [1.816330909729004, -1.2812227010726929],
                [0.7921574711799622, -2.1135804653167725],
                [0.5817914605140686, 1.2644577026367188],
                [1.929347038269043, -0.2386285960674286],
                [0.8877345323562622, 1.190008521080017],
                [1.4732073545455933, 0.8935023546218872],
                [-2.8518524169921875, -1.5478795766830444],
                [0.2439267635345459, 0.7576767802238464],
                [0.5246709585189819, -2.606659412384033],
                [1.150876760482788, 1.4073830842971802],
                [-0.2643202245235443, 2.0634236335754395],
                [1.555483341217041, -0.0023102816194295883],
                [2.0830578804016113, -1.7225427627563477],
                [-0.5424830317497253, -1.070199728012085],
                [0.9168899655342102, 0.8955540060997009],
                [-0.8120972514152527, 2.696739912033081],
                [-0.29908373951911926, -1.5310651063919067],
                [1.2320337295532227, -1.556247353553772],
                [1.8612544536590576, 0.08704725652933121],
                [0.22133447229862213, -1.8091708421707153],
                [-0.4403655230998993, -0.38571012020111084],
                [-1.88539457321167, 1.192205786705017],
                [2.239687919616699, 0.004709010478109121],
                [1.139495611190796, 0.45733731985092163],
                [-1.507995367050171, 0.19716016948223114],
                [0.46986445784568787, 1.5422041416168213],
                [-1.2573751211166382, -0.35984551906585693],
                [-1.7415345907211304, -0.6020717024803162],
                [1.0751984119415283, 0.19006384909152985],
                [2.24186635017395, -0.46343153715133667],
                [0.3610347509384155, -0.07658443599939346],
                [-1.3111497163772583, 0.432013601064682],
                [0.6164408326148987, 0.24538464844226837],
                [-1.9266542196273804, -0.3256155550479889],
                [-0.5870336890220642, -0.1879584938287735],
                [-1.0476511716842651, 0.3677721917629242],
                [-1.229940414428711, 1.2433830499649048],
                [0.18550436198711395, 0.22753673791885376],
                [-0.017921989783644676, 0.12625974416732788],
                [1.1659504175186157, -0.5020995736122131],
                [-0.5983408093452454, -1.40438973903656],
                [0.7519024014472961, -0.16282692551612854],
                [0.9920787811279297, -1.344896912574768],
                [-0.8103678226470947, 0.3064485788345337],
                [0.6956969499588013, 1.8208192586898804],
                [-2.7830491065979004, -0.2299390584230423],
                [-0.34681546688079834, 2.4890666007995605],
                [-1.4452646970748901, -1.2216600179672241],
                [-2.1872897148132324, 0.8926076292991638],
                [1.706072211265564, -2.8440372943878174],
                [1.1119003295898438, -2.4923460483551025],
                [-2.582794666290283, 2.0973289012908936],
                [0.04987720400094986, -0.2964983284473419],
                [-2.063807487487793, -0.7847916483879089],
                [-0.4068813621997833, 0.9135897755622864],
                [-0.9814359545707703, -0.3874954879283905],
                [-1.4227229356765747, 0.7337291240692139],
                [0.3065044581890106, 1.3125417232513428],
                [1.2160996198654175, -1.9643305540084839],
                [-1.2163853645324707, 0.14608727395534515],
                [-2.3030710220336914, -0.37558120489120483],
                [0.9232977628707886, 2.1843791007995605],
                [-0.1989777386188507, 1.651851773262024],
                [-0.714374840259552, -0.39365994930267334],
                [-0.7805715799331665, -2.099881887435913],
                [0.9015759229660034, -1.7053706645965576],
                [0.1033422127366066, 1.5256654024124146],
                [-1.8773194551467896, 2.324174165725708],
                [1.9227174520492554, 2.7441604137420654],
                [-0.5994020104408264, 0.23984014987945557],
                [1.3496100902557373, -0.9126054644584656],
                [-0.8765304088592529, -3.1877026557922363],
                [-1.2040035724639893, -1.5169521570205688],
                [1.4261796474456787, 2.150200128555298],
                [1.463774561882019, 1.6656692028045654],
                [0.20364105701446533, -0.4988172650337219],
                [0.5195154547691345, -0.24067887663841248],
                [-1.1116786003112793, -1.1599653959274292],
                [-0.8490808606147766, -0.1681060940027237],
                [0.3189965784549713, -0.9641751646995544],
                [-0.5664751529693604, -0.5951744318008423],
                [-1.6347930431365967, -0.9137664437294006],
                [0.44048091769218445, -0.47259435057640076],
                [-2.147747039794922, 0.47442489862442017],
                [1.834734320640564, 1.4462147951126099],
                [1.1777573823928833, 1.0659226179122925],
                [-0.9568989872932434, 0.09495053440332413],
                [-1.838529348373413, 0.2950586676597595],
                [-0.4800611734390259, 0.014894310384988785],
                [-0.5235516428947449, -1.7687653303146362],
                [2.0735011100769043, -0.8825281262397766],
                [2.637502431869507, 0.8455678224563599],
                [2.606602907180786, -0.7848446369171143],
                [-1.1886937618255615, 0.9330510497093201],
                [0.38082656264305115, 0.13328030705451965],
                [0.6847941875457764, 0.7384101152420044],
                [1.2638574838638306, -0.007309418171644211],
                [0.18292222917079926, -1.22371244430542],
                [0.8143821954727173, 1.4976691007614136],
                [0.6571850776672363, 0.48368802666664124],
                [-0.6991601586341858, 2.150190830230713],
                [0.8101756572723389, 0.10206498205661774],
                [-0.08768226951360703, -1.084917664527893],
                [-0.7208092212677002, 0.03657956421375275],
                [0.3211449086666107, 1.803687334060669],
                [-0.7835946083068848, 1.6869111061096191],
            ]
        )
    # 2-dimensional grid, 64 points. This is a plain `if` rather than `elif`; the result is the same because the
    # branch above always returns.
    if (p, n) == (2, 64):
        return torch.tensor(
            [
                [-2.7216711044311523, 0.14431366324424744],
                [-0.766914427280426, 1.7193410396575928],
                [-2.2575762271881104, 1.2476624250411987],
                [1.233758807182312, -2.3560616970062256],
                [0.8701965808868408, -0.2649352252483368],
                [1.4506438970565796, 2.1776366233825684],
                [-0.06305818259716034, 1.9049758911132812],
                [2.536226511001587, 0.563927412033081],
                [0.4599496126174927, -1.8745561838150024],
                [-1.900517225265503, -0.30703988671302795],
                [0.09386251866817474, 0.8755807280540466],
                [1.946500539779663, -0.6743080615997314],
                [2.1338934898376465, 1.4581491947174072],
                [0.9429940581321716, -0.8038390278816223],
                [2.0697755813598633, -1.614896535873413],
                [0.772676408290863, 0.22017823159694672],
                [1.0689979791641235, -1.525044322013855],
                [0.6813604831695557, 1.1345642805099487],
                [0.4706456661224365, 2.606626272201538],
                [-1.294018030166626, -0.4372096061706543],
                [-0.09134224057197571, 0.4610418677330017],
                [-0.7907772064208984, -0.48412787914276123],
                [0.060459110885858536, -0.9172890186309814],
                [-0.5855047702789307, 2.56172513961792],
                [0.11484206467866898, -2.659848213195801],
                [-1.5893300771713257, 2.188580274581909],
                [1.6750942468643188, 0.7089915871620178],
                [-0.445697546005249, 0.7452405095100403],
                [-1.8539940118789673, -1.8377939462661743],
                [-1.5791912078857422, -1.017285943031311],
                [-1.030419945716858, -1.5746369361877441],
                [-1.9511750936508179, 0.43696075677871704],
                [-0.3446580767631531, -1.8953213691711426],
                [-1.4219647645950317, 0.7676230669021606],
                [-0.9191089272499084, 0.5021472573280334],
                [0.20464491844177246, 1.3684605360031128],
                [0.5402919054031372, 0.6699410676956177],
                [1.8903915882110596, 0.03638288006186485],
                [0.4723062515258789, -0.6216739416122437],
                [-0.41345009207725525, -0.22752176225185394],
                [2.7119064331054688, -0.5111885070800781],
                [1.065286636352539, 0.6950305700302124],
                [0.40629103779792786, -0.14339995384216309],
                [1.2815024852752686, 0.17108257114887238],
                [0.01785222627222538, -0.43778058886528015],
                [0.054590027779340744, -1.4225547313690186],
                [0.3076786696910858, 0.30697619915008545],
                [-0.9498570561408997, -0.9576997756958008],
                [-2.4640724658966064, -0.9660449028015137],
                [1.3714425563812256, -0.39760473370552063],
                [-0.4857747256755829, 0.2386789172887802],
                [1.2797833681106567, 1.3097363710403442],
                [0.5508887767791748, -1.1777795553207397],
                [-1.384316325187683, 0.1465839296579361],
                [-0.46556955575942993, -1.2442727088928223],
                [-0.3915477693080902, -0.7319604158401489],
                [-1.4005504846572876, 1.3890998363494873],
                [-0.8647305965423584, 1.0617644786834717],
                [-0.8901953101158142, -0.01650036871433258],
                [-0.9893633723258972, -2.4662880897521973],
                [1.445534110069275, -1.049334168434143],
                [-0.041650623083114624, 0.012734669260680676],
                [-0.3302375078201294, 1.26217782497406],
                [0.6934980154037476, 1.7714335918426514],
            ]
        )
    # 2-dimensional grid, 16 points.
    elif (p, n) == (2, 16):
        return torch.tensor(
            [
                [-0.8996632695198059, -1.6360418796539307],
                [-0.961183488368988, 1.5999565124511719],
                [-1.882026195526123, 0.678778350353241],
                [0.36300793290138245, -1.9667866230010986],
                [-0.6814072728157043, -0.576818585395813],
                [0.7270012497901917, 0.6186859607696533],
                [0.3359416127204895, 1.8371193408966064],
                [1.859930396080017, 0.036668598651885986],
                [0.17208248376846313, -0.9401724338531494],
                [-1.7599700689315796, -0.6244229674339294],
                [-0.8993809223175049, 0.32267823815345764],
                [0.839488685131073, -0.3017036020755768],
                [1.5314953327178955, 1.2942044734954834],
                [-0.0011779458727687597, 0.00022069070837460458],
                [1.4274526834487915, -1.207889199256897],
                [-0.16123905777931213, 0.8787511587142944],
            ]
        )
    # 1-dimensional grid, 16 points (values are symmetric around zero and sorted ascending).
    elif (p, n) == (1, 16):
        return torch.tensor(
            [
                [-2.7325894832611084],
                [-2.069017171859741],
                [-1.6180464029312134],
                [-1.2562311887741089],
                [-0.9423404335975647],
                [-0.6567591428756714],
                [-0.38804829120635986],
                [-0.12839503586292267],
                [0.12839503586292267],
                [0.38804829120635986],
                [0.6567591428756714],
                [0.9423404335975647],
                [1.2562311887741089],
                [1.6180464029312134],
                [2.069017171859741],
                [2.7325894832611084],
            ]
        )
    # 1-dimensional grid, 8 points (symmetric around zero, sorted ascending).
    elif (p, n) == (1, 8):
        return torch.tensor(
            [
                [-2.1519455909729004],
                [-1.3439092636108398],
                [-0.7560052871704102],
                [-0.2450941801071167],
                [0.2450941801071167],
                [0.7560052871704102],
                [1.3439092636108398],
                [2.1519455909729004],
            ]
        )
    # 1-dimensional grid, 4 points (symmetric around zero, sorted ascending).
    elif (p, n) == (1, 4):
        return torch.tensor([[-1.5104175806045532], [-0.4527800381183624], [0.4527800381183624], [1.5104175806045532]])
    else:
        # No table is available for any other combination.
        raise NotImplementedError(f"Unsupported p={p}, n={n}")
  425. def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256, hadamard_size: int = 1024):
  426. assert len(weight.shape) == 2, "Only 2D weights are supported for now"
  427. grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device)
  428. grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2
  429. device = weight.device
  430. dtype = weight.dtype
  431. weight = weight.to(copy=True, dtype=torch.float32)
  432. # Pad to Hadamard transform size
  433. weight = pad_to_block(weight, [1], hadamard_size)
  434. # Scale and Hadamard transform
  435. mult = weight.shape[1] // hadamard_size
  436. weight = weight.reshape(-1, mult, hadamard_size)
  437. scales = torch.linalg.norm(weight, axis=-1)
  438. weight = hadamard_transform(weight, 1) / scales[:, :, None]
  439. # Pad to edenn_d and project
  440. weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p)
  441. # Quantize
  442. codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
  443. for i in range(0, weight.shape[0], 16):
  444. codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
  445. del weight
  446. codes = codes.reshape(codes.shape[0], -1)
  447. scales = scales / sqrt(hadamard_size)
  448. weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
  449. codes,
  450. torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
  451. grid.to(dtype),
  452. num_bits=bits,
  453. group_size=group_size,
  454. vector_size=p,
  455. dtype=dtype,
  456. device=device,
  457. check_correctness=False,
  458. )
  459. return {
  460. "weight": weight,
  461. "scales": scales,
  462. "tables": tables,
  463. "tables2": tables2.view(dtype=torch.float16),
  464. "tune_metadata": tune_metadata,
  465. }
  466. class HiggsLinear(torch.nn.Module):
  467. def __init__(
  468. self,
  469. in_features: int,
  470. out_features: int,
  471. num_bits: int,
  472. bias=True,
  473. dtype: torch.dtype | None = None,
  474. device: torch.device | None = None,
  475. group_size: int = 256,
  476. hadamard_size: int = 1024,
  477. ):
  478. super().__init__()
  479. self.in_features = in_features
  480. self.out_features = out_features
  481. self.num_bits = num_bits
  482. self.group_size = group_size
  483. self.hadamard_size = hadamard_size
  484. assert in_features % group_size == 0
  485. assert num_bits in [2, 3, 4]
  486. self.weight = nn.Parameter(
  487. torch.empty((out_features * num_bits // 16, in_features), dtype=torch.int16, device=device),
  488. requires_grad=False,
  489. )
  490. self.scales = nn.Parameter(
  491. torch.empty((out_features, in_features // group_size), dtype=dtype, device=device), requires_grad=False
  492. )
  493. self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False)
  494. self.tables2 = nn.Parameter(
  495. torch.empty((2**num_bits, 2**num_bits, 2), dtype=dtype, device=device), requires_grad=False
  496. )
  497. if bias:
  498. self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False)
  499. else:
  500. self.register_parameter("bias", None)
  501. self.workspace = None # must be set externally to be reused among layers
  502. self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent
  503. def forward(self, x):
  504. x = pad_to_block(x, [-1], self.hadamard_size)
  505. if self.workspace is None:
  506. raise Exception("Workspace must be set before calling forward")
  507. return qgemm_v2(
  508. x,
  509. self.weight,
  510. self.scales,
  511. self.tables,
  512. self.tables2.view(dtype=torch.float32),
  513. self.workspace,
  514. self.tune_metadata,
  515. hadamard_size=self.hadamard_size,
  516. )
  517. def replace_with_higgs_linear(model, modules_to_not_convert: list[str] | None = None, quantization_config=None):
  518. """
  519. Public method that replaces the Linear layers of the given model with HIGGS quantized layers.
  520. Args:
  521. model (`torch.nn.Module`):
  522. The model to convert, can be any `torch.nn.Module` instance.
  523. modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
  524. A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
  525. converted.
  526. quantization_config (`HiggsConfig`):
  527. The quantization config object that contains the quantization parameters.
  528. """
  529. has_been_replaced = False
  530. # we need this to correctly materialize the weights during quantization
  531. for module_name, module in model.named_modules():
  532. if not should_convert_module(module_name, modules_to_not_convert):
  533. continue
  534. with torch.device("meta"):
  535. if isinstance(module, nn.Linear):
  536. new_module = HiggsLinear(
  537. module.in_features,
  538. module.out_features,
  539. bias=module.bias is not None,
  540. num_bits=quantization_config.bits,
  541. hadamard_size=quantization_config.hadamard_size,
  542. group_size=quantization_config.group_size,
  543. )
  544. new_module.source_cls = type(module)
  545. new_module.requires_grad_(False)
  546. model.set_submodule(module_name, new_module)
  547. has_been_replaced = True
  548. if not has_been_replaced:
  549. logger.warning(
  550. "You are loading your model using eetq but no linear modules were found in your model."
  551. " Please double check your model architecture, or submit an issue on github if you think this is"
  552. " a bug."
  553. )
  554. return model
  555. def dequantize_higgs(model, current_key_name=None):
  556. """
  557. Dequantizes the HiggsLinear layers in the given model by replacing them with standard torch.nn.Linear layers.
  558. Args:
  559. model (torch.nn.Module): The model containing HiggsLinear layers to be dequantized.
  560. current_key_name (list, optional): A list to keep track of the current module names during recursion. Defaults to None.
  561. Returns:
  562. torch.nn.Module: The model with HiggsLinear layers replaced by torch.nn.Linear layers.
  563. """
  564. with torch.no_grad():
  565. for name, module in model.named_children():
  566. if current_key_name is None:
  567. current_key_name = []
  568. current_key_name.append(name)
  569. if isinstance(module, HiggsLinear):
  570. in_features = module.in_features
  571. out_features = module.out_features
  572. model._modules[name] = torch.nn.Linear(
  573. in_features,
  574. out_features,
  575. bias=module.bias is not None,
  576. device=module.scales.device,
  577. dtype=module.scales.dtype,
  578. )
  579. model._modules[name].weight.data = module(
  580. torch.eye(in_features, device=module.scales.device, dtype=module.scales.dtype)
  581. ).T.contiguous()
  582. if len(list(module.children())) > 0:
  583. _ = dequantize_higgs(
  584. module,
  585. current_key_name=current_key_name,
  586. )
  587. # Remove the last key for recursion
  588. current_key_name.pop(-1)
  589. return model