# _shape_functions.py (51 KB)
  1. # mypy: allow-untyped-defs
  2. import math
  3. from collections.abc import Callable
  4. from typing import Any, Optional, Union
  5. number = Union[int, float]
  6. # flake8: noqa
  7. ###
  8. # There are generated files that depend on this file
  9. # To re-generate, please run from the root of the repo:
  10. # python torchgen/shape_functions/gen_jit_shape_functions.py
  11. # How to test:
  12. # After regenerating files, compile PyTorch.
  13. # Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
  14. # If you have enabled opinfo testing for the op, also run:
  15. # python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
  16. # to reproduce errors from opinfo tests.
  17. # Example PR: https://github.com/pytorch/pytorch/pull/80860/files
  18. ####
  19. import torch
  20. def broadcast(a: list[int], b: list[int]):
  21. dimsA = len(a)
  22. dimsB = len(b)
  23. ndim = max(dimsA, dimsB)
  24. expandedSizes: list[int] = []
  25. for i in range(ndim):
  26. offset = ndim - 1 - i
  27. dimA = dimsA - 1 - offset
  28. dimB = dimsB - 1 - offset
  29. sizeA = a[dimA] if (dimA >= 0) else 1
  30. sizeB = b[dimB] if (dimB >= 0) else 1
  31. if sizeA != sizeB and sizeA != 1 and sizeB != 1:
  32. # TODO: only assertion error is bound in C++ compilation right now
  33. raise AssertionError(
  34. f"The size of tensor a {sizeA} must match the size of tensor b ({sizeB}) at non-singleton dimension {i}"
  35. )
  36. expandedSizes.append(sizeB if sizeA == 1 else sizeA)
  37. return expandedSizes
  38. def broadcast_three(a: list[int], b: list[int], c: list[int]):
  39. return broadcast(broadcast(a, b), c)
  40. def broadcast_one_three(a: list[int], b: Any, c: list[int]):
  41. return broadcast(a, c)
  42. def adaptive_avg_pool2d(self: list[int], out: list[int]):
  43. if len(out) != 2:
  44. raise AssertionError(f"Expected out to have length 2, but got {len(out)}")
  45. if not (len(self) == 3 or len(self) == 4):
  46. raise AssertionError(
  47. f"Expected self to have length 3 or 4, but got {len(self)}"
  48. )
  49. for i in range(1, len(self)):
  50. if self[i] == 0:
  51. raise AssertionError(f"Expected self[{i}] to be non-zero, but got 0")
  52. shape: list[int] = []
  53. for i in range(0, len(self) - 2):
  54. shape.append(self[i])
  55. for elem in out:
  56. shape.append(elem)
  57. return shape
  58. def _copy(self: list[int]):
  59. out: list[int] = []
  60. for elem in self:
  61. out.append(elem)
  62. return out
  63. def unary(self: list[int]):
  64. return _copy(self)
  65. def broadcast_inplace(a: list[int], b: list[int]):
  66. dimsA = len(a)
  67. dimsB = len(b)
  68. if dimsB > dimsA:
  69. raise AssertionError(
  70. f"The dims of tensor b ({dimsB}) must be less than or equal to the dims of tensor a ({dimsA}) "
  71. )
  72. for dimA in range(dimsA):
  73. dimB = dimsB - dimsA + dimA
  74. sizeA = a[dimA]
  75. sizeB = b[dimB] if (dimB >= 0) else 1
  76. if sizeA != sizeB and sizeB != 1:
  77. # TODO: only assertion error is bound in C++ compilation right now
  78. raise AssertionError(
  79. "The size of tensor a {} must match the size of tensor b ("
  80. "{}) at non-singleton dimension {}".format(sizeA, sizeB, dimA)
  81. )
  82. return _copy(a)
  83. def expand(self: list[int], sizes: list[int]):
  84. if len(sizes) < len(self):
  85. raise AssertionError(
  86. f"Expected len(sizes) ({len(sizes)}) >= len(self) ({len(self)})"
  87. )
  88. ndim = len(sizes)
  89. tensor_dim = len(self)
  90. if ndim == 0:
  91. return _copy(sizes)
  92. out: list[int] = []
  93. for i in range(ndim):
  94. offset = ndim - 1 - i
  95. dim = tensor_dim - 1 - offset
  96. size = self[dim] if dim >= 0 else 1
  97. targetSize = sizes[i]
  98. if targetSize == -1:
  99. if dim < 0:
  100. raise AssertionError(f"Expected dim ({dim}) >= 0 when targetSize is -1")
  101. targetSize = size
  102. if size != targetSize:
  103. if size != 1:
  104. raise AssertionError(
  105. f"Expected size ({size}) == 1 when size != targetSize ({targetSize})"
  106. )
  107. size = targetSize
  108. out.append(size)
  109. return out
  110. def expand_one_unused(self: list[int], sizes: list[int], inp0: Any):
  111. return expand(self, sizes)
  112. def infer_size_impl(shape: list[int], numel: int) -> list[int]:
  113. newsize = 1
  114. infer_dim: Optional[int] = None
  115. for dim in range(len(shape)):
  116. if shape[dim] == -1:
  117. if infer_dim is not None:
  118. raise AssertionError("only one dimension can be inferred")
  119. infer_dim = dim
  120. elif shape[dim] >= 0:
  121. newsize *= shape[dim]
  122. else:
  123. raise AssertionError("invalid shape dimensions")
  124. if not (
  125. numel == newsize
  126. or (infer_dim is not None and newsize > 0 and numel % newsize == 0)
  127. ):
  128. raise AssertionError("invalid shape")
  129. out = _copy(shape)
  130. if infer_dim is not None:
  131. out[infer_dim] = numel // newsize
  132. return out
  133. def numel(sizes: list[int]):
  134. numel = 1
  135. for elem in sizes:
  136. numel *= elem
  137. return numel
  138. def view(self: list[int], sizes: list[int]):
  139. return infer_size_impl(sizes, numel(self))
  140. def view_one_unused(self: list[int], sizes: list[int], *, implicit: bool = False):
  141. return view(self, sizes)
  142. def sum_mean_dim(
  143. self: list[int], opt_dims: Optional[list[int]], keep_dim: bool, dt: Any
  144. ):
  145. out: list[int] = []
  146. if opt_dims is None or len(opt_dims) == 0:
  147. dims: list[int] = list(range(len(self)))
  148. else:
  149. dims = opt_dims
  150. for idx in range(len(self)):
  151. is_mean_dim: bool = False
  152. for reduce_dim in dims:
  153. if idx == maybe_wrap_dim(reduce_dim, len(self)):
  154. is_mean_dim = True
  155. if is_mean_dim:
  156. if keep_dim:
  157. out.append(1)
  158. else:
  159. out.append(self[idx])
  160. return out
  161. def max_dim(self: list[int], dim: int, keep_dim: bool):
  162. out = sum_mean_dim(self, [dim], keep_dim, None)
  163. return out, out
  164. # note: python already rounds down towards negative infinity on integer division, special arithmetic not needed
  165. def div_rtn(x: int, y: int):
  166. return x // y
  167. def pooling_output_shape_pad_lr(
  168. inputSize: int,
  169. kernelSize: int,
  170. pad_l: int,
  171. pad_r: int,
  172. stride: int,
  173. dilation: int,
  174. ceil_mode: bool,
  175. ):
  176. outputSize = (
  177. div_rtn(
  178. inputSize
  179. + pad_l
  180. + pad_r
  181. - dilation * (kernelSize - 1)
  182. - 1
  183. + (stride - 1 if ceil_mode else 0),
  184. stride,
  185. )
  186. + 1
  187. )
  188. if ceil_mode:
  189. if (outputSize - 1) * stride >= inputSize + pad_l:
  190. outputSize = outputSize - 1
  191. return outputSize
  192. def pooling_output_shape(
  193. inputSize: int,
  194. kernelSize: int,
  195. pad_l: int,
  196. stride: int,
  197. dilation: int,
  198. ceil_mode: bool,
  199. ):
  200. if stride == 0:
  201. raise AssertionError("stride should not be zero")
  202. return pooling_output_shape_pad_lr(
  203. inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode
  204. )
  205. def pool2d_shape_check(
  206. input: list[int],
  207. kH: int,
  208. kW: int,
  209. dH: int,
  210. dW: int,
  211. padH: int,
  212. padW: int,
  213. dilationH: int,
  214. dilationW: int,
  215. nInputPlane: int,
  216. inputHeight: int,
  217. inputWidth: int,
  218. outputHeight: int,
  219. outputWidth: int,
  220. ):
  221. ndim = len(input)
  222. if not (kW > 0 and kH > 0):
  223. raise AssertionError(f"Expected kW ({kW}) > 0 and kH ({kH}) > 0")
  224. if not (dW > 0 and dH > 0):
  225. raise AssertionError(f"Expected dW ({dW}) > 0 and dH ({dH}) > 0")
  226. if not (dilationH > 0 and dilationW > 0):
  227. raise AssertionError(
  228. f"Expected dilationH ({dilationH}) > 0 and dilationW ({dilationW}) > 0"
  229. )
  230. valid_dims = input[1] != 0 and input[2] != 0
  231. if not (
  232. ndim == 3
  233. and input[0] != 0
  234. and valid_dims
  235. or (ndim == 4 and valid_dims and input[3] != 0)
  236. ):
  237. raise AssertionError(f"Invalid input dimensions: ndim={ndim}, input={input}")
  238. if not (kW // 2 >= padW and kH // 2 >= padH):
  239. raise AssertionError(
  240. f"Expected kW//2 ({kW // 2}) >= padW ({padW}) and "
  241. f"kH//2 ({kH // 2}) >= padH ({padH})"
  242. )
  243. if not (outputWidth >= 1 and outputHeight >= 1):
  244. raise AssertionError(
  245. f"Expected outputWidth ({outputWidth}) >= 1 and "
  246. f"outputHeight ({outputHeight}) >= 1"
  247. )
  248. def max_pool2d(
  249. input: list[int],
  250. kernel_size: list[int],
  251. stride: list[int],
  252. padding: list[int],
  253. dilation: list[int],
  254. ceil_mode: bool,
  255. ):
  256. if not (len(kernel_size) == 1 or len(kernel_size) == 2):
  257. raise AssertionError(
  258. "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
  259. )
  260. kH = kernel_size[0]
  261. kW = kH if len(kernel_size) == 1 else kernel_size[1]
  262. if not (len(stride) == 0 or len(stride) == 1 or len(stride) == 2):
  263. raise AssertionError(
  264. "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
  265. )
  266. dH = kH if len(stride) == 0 else stride[0]
  267. if len(stride) == 0:
  268. dW = kW
  269. elif len(stride) == 1:
  270. dW = dH
  271. else:
  272. dW = stride[1]
  273. if not (len(padding) == 1 or len(padding) == 2):
  274. raise AssertionError(
  275. "max_pool2d: padding must either be a single int, or a tuple of two ints"
  276. )
  277. padH = padding[0]
  278. padW = padH if len(padding) == 1 else padding[1]
  279. if not (len(dilation) == 1 or len(dilation) == 2):
  280. raise AssertionError(
  281. "max_pool2d: dilation must be either a single int, or a tuple of two ints"
  282. )
  283. dilationH = dilation[0]
  284. dilationW = dilationH if len(dilation) == 1 else dilation[1]
  285. if not (len(input) == 3 or len(input) == 4):
  286. raise AssertionError(f"Expected input length 3 or 4, but got {len(input)}")
  287. nbatch = input[-4] if len(input) == 4 else 1
  288. nInputPlane = input[-3]
  289. inputHeight = input[-2]
  290. inputWidth = input[-1]
  291. outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
  292. outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
  293. pool2d_shape_check(
  294. input,
  295. kH,
  296. kW,
  297. dH,
  298. dW,
  299. padH,
  300. padW,
  301. dilationH,
  302. dilationW,
  303. nInputPlane,
  304. inputHeight,
  305. inputWidth,
  306. outputHeight,
  307. outputWidth,
  308. )
  309. if len(input) == 3:
  310. return [nInputPlane, outputHeight, outputWidth]
  311. else:
  312. return [nbatch, nInputPlane, outputHeight, outputWidth]
  313. def max_pool2d_with_indices(
  314. input: list[int],
  315. kernel_size: list[int],
  316. stride: list[int],
  317. padding: list[int],
  318. dilation: list[int],
  319. ceil_mode: bool,
  320. ):
  321. out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  322. return (out, out)
  323. def upsample_nearest2d(
  324. input: list[int],
  325. output_size: Optional[list[int]],
  326. scale_factors: Optional[list[float]],
  327. ):
  328. out: list[int] = []
  329. out.append(input[0])
  330. out.append(input[1])
  331. if scale_factors is None and output_size is None:
  332. raise AssertionError("Either output_size or scale_factors must be presented")
  333. if output_size is not None:
  334. if scale_factors is not None:
  335. raise AssertionError(
  336. "Must specify exactly one of output_size and scale_factors"
  337. )
  338. if len(output_size) != 2:
  339. raise AssertionError(
  340. f"Expected output_size to have length 2, but got {len(output_size)}"
  341. )
  342. out.append(output_size[0])
  343. out.append(output_size[1])
  344. if scale_factors is not None:
  345. if output_size is not None:
  346. raise AssertionError(
  347. "Must specify exactly one of output_size and scale_factors"
  348. )
  349. if len(scale_factors) != 2:
  350. raise AssertionError(
  351. f"Expected scale_factors to have length 2, but got {len(scale_factors)}"
  352. )
  353. out.append(int(input[2] * scale_factors[0]))
  354. out.append(int(input[3] * scale_factors[1]))
  355. return out
  356. def mm(self: list[int], mat2: list[int]):
  357. if len(self) != 2:
  358. raise AssertionError(f"self must be a matrix (got {len(self)} dimensions)")
  359. if len(mat2) != 2:
  360. raise AssertionError(f"mat2 must be a matrix (got {len(mat2)} dimensions)")
  361. if self[1] != mat2[0]:
  362. raise AssertionError(
  363. f"Matrix dimensions don't match for mm: self[1]={self[1]}, mat2[0]={mat2[0]}"
  364. )
  365. return [self[0], mat2[1]]
  366. def dot(self: list[int], tensor: list[int]):
  367. if not (len(self) == 1 and len(tensor) == 1):
  368. raise AssertionError(
  369. f"Expected 1D tensors for dot, got len(self)={len(self)}, "
  370. f"len(tensor)={len(tensor)}"
  371. )
  372. if self[0] != tensor[0]:
  373. raise AssertionError(
  374. f"Dot product dimension mismatch: self[0]={self[0]}, tensor[0]={tensor[0]}"
  375. )
  376. out: list[int] = []
  377. return out
  378. def mv(self: list[int], vec: list[int]):
  379. if not (len(self) == 2 and len(vec) == 1):
  380. raise AssertionError(
  381. f"Expected 2D matrix and 1D vector, got len(self)={len(self)}, "
  382. f"len(vec)={len(vec)}"
  383. )
  384. if self[1] != vec[0]:
  385. raise AssertionError(
  386. f"Matrix-vector dimension mismatch: self[1]={self[1]}, vec[0]={vec[0]}"
  387. )
  388. # TODO: return self
  389. return [self[0]]
  390. def unsqueeze(li: list[int], dim: int):
  391. dim = maybe_wrap_dim(dim, len(li) + 1)
  392. out = _copy(li)
  393. out.insert(dim, 1)
  394. return out
  395. def squeeze_nodim(li: list[int]):
  396. out: list[int] = []
  397. for i in range(len(li)):
  398. if li[i] != 1:
  399. out.append(li[i])
  400. return out
  401. def squeeze(li: list[int], dim: int):
  402. out: list[int] = []
  403. wrapped_dim = maybe_wrap_dim(dim, len(li))
  404. for i in range(len(li)):
  405. if i == wrapped_dim:
  406. if li[i] != 1:
  407. out.append(li[i])
  408. else:
  409. out.append(li[i])
  410. return out
  411. def squeeze_dims(li: list[int], dims: list[int]):
  412. if len(dims) == 0:
  413. return li
  414. wrapped_dims = _copy(dims)
  415. for i in range(len(dims)):
  416. wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li))
  417. result: list[int] = []
  418. for i in range(len(li)):
  419. if li[i] == 1:
  420. if i not in wrapped_dims:
  421. result.append(li[i])
  422. else:
  423. result.append(li[i])
  424. return result
  425. def index_select(self: list[int], dim: int, index: list[int]):
  426. dim = maybe_wrap_dim(dim, len(self))
  427. numel = multiply_integers(index)
  428. if len(index) > 1:
  429. raise AssertionError(f"Expected len(index) <= 1, but got {len(index)}")
  430. if not (dim == 0 or dim < len(self)):
  431. raise AssertionError(
  432. f"Expected dim ({dim}) == 0 or dim < len(self) ({len(self)})"
  433. )
  434. result_size: list[int] = []
  435. for i in range(len(self)):
  436. if dim == i:
  437. result_size.append(numel)
  438. else:
  439. result_size.append(self[i])
  440. return result_size
  441. def embedding(
  442. weight: list[int],
  443. indices: list[int],
  444. padding_idx: int = -1,
  445. scale_grad_by_freq: bool = False,
  446. sparse: bool = False,
  447. ):
  448. if len(weight) != 2:
  449. raise AssertionError(f"Expected weight to be 2D, but got {len(weight)}D")
  450. if len(indices) == 1:
  451. return index_select(weight, 0, indices)
  452. size = _copy(indices)
  453. size.append(weight[1])
  454. return size
  455. def max_int():
  456. return 9223372036854775807
  457. def slice(
  458. self: list[int], dim: int, start: Optional[int], end: Optional[int], step: int
  459. ):
  460. ndim = len(self)
  461. if ndim == 0:
  462. raise AssertionError("Cannot slice a 0-dimensional tensor")
  463. dim = maybe_wrap_dim(dim, ndim)
  464. start_val = start if start is not None else 0
  465. end_val = end if end is not None else max_int()
  466. if step <= 0:
  467. raise AssertionError(f"Expected step > 0, but got {step}")
  468. if start_val == max_int():
  469. start_val = 0
  470. if start_val < 0:
  471. start_val += self[dim]
  472. if end_val < 0:
  473. end_val += self[dim]
  474. if start_val < 0:
  475. start_val = 0
  476. elif start_val > self[dim]:
  477. start_val = self[dim]
  478. if end_val < start_val:
  479. end_val = start_val
  480. elif end_val >= self[dim]:
  481. end_val = self[dim]
  482. slice_len = end_val - start_val
  483. out = _copy(self)
  484. out[dim] = (slice_len + step - 1) // step
  485. return out
  486. def check_cat_no_zero_dim(tensors: list[list[int]]):
  487. for tensor in tensors:
  488. if len(tensor) <= 0:
  489. raise AssertionError("Cannot concatenate tensor with 0 dimensions")
  490. def legacy_cat_wrap_dim(dim: int, tensor_sizes: list[list[int]]):
  491. out_dim: Optional[int] = None
  492. for size in tensor_sizes:
  493. if not (len(size) == 1 and size[0] == 0):
  494. if out_dim is None:
  495. out_dim = maybe_wrap_dim(dim, len(size))
  496. if out_dim is None:
  497. out_dim = dim
  498. return out_dim
  499. def should_skip(tensor: list[int]):
  500. return numel(tensor) == 0 and len(tensor) == 1
  501. def check_cat_shape_except_dim(
  502. first: list[int], second: list[int], dimension: int, index: int
  503. ):
  504. first_dims = len(first)
  505. second_dims = len(second)
  506. if first_dims != second_dims:
  507. raise AssertionError(
  508. f"Tensors must have same number of dimensions, got {first_dims} and "
  509. f"{second_dims}"
  510. )
  511. for dim in range(0, first_dims):
  512. if dim != dimension:
  513. if first[dim] != second[dim]:
  514. raise AssertionError(
  515. f"Sizes of tensors must match except in dimension {dimension}, "
  516. f"got {first[dim]} and {second[dim]} at dimension {dim}"
  517. )
  518. def cat(tensors: list[list[int]], dim: int):
  519. check_cat_no_zero_dim(tensors)
  520. dim = legacy_cat_wrap_dim(dim, tensors)
  521. if len(tensors) <= 0:
  522. raise AssertionError("Cannot concatenate empty list of tensors")
  523. not_skipped_tensor: Optional[list[int]] = None
  524. for tensor in tensors:
  525. if not should_skip(tensor):
  526. not_skipped_tensor = tensor
  527. if not_skipped_tensor is None:
  528. return [0]
  529. cat_dim_size = 0
  530. for i in range(len(tensors)):
  531. tensor = tensors[i]
  532. if not should_skip(tensor):
  533. check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
  534. cat_dim_size = cat_dim_size + tensor[dim]
  535. result_size = _copy(not_skipped_tensor)
  536. result_size[dim] = cat_dim_size
  537. return result_size
  538. def stack(tensors: list[list[int]], dim: int):
  539. unsqueezed_tensors: list[list[int]] = []
  540. for tensor in tensors:
  541. unsqueezed = unsqueeze(tensor, dim)
  542. unsqueezed_tensors.append(unsqueezed)
  543. return cat(unsqueezed_tensors, dim)
  544. def select(self: list[int], dim: int, index: int):
  545. ndim = len(self)
  546. if ndim == 0:
  547. raise AssertionError("Cannot select from a 0-dimensional tensor")
  548. dim = maybe_wrap_dim(dim, ndim)
  549. size = self[dim]
  550. if index < -size or index >= size:
  551. raise AssertionError(
  552. f"Index {index} is out of bounds for dimension {dim} with size {size}"
  553. )
  554. if index < 0:
  555. index += size
  556. out: list[int] = []
  557. for i in range(ndim):
  558. if i != dim:
  559. out.append(self[i])
  560. return out
  561. def matmul(tensor1: list[int], tensor2: list[int]):
  562. dim_tensor1 = len(tensor1)
  563. dim_tensor2 = len(tensor2)
  564. if dim_tensor1 == 1 and dim_tensor2 == 1:
  565. return dot(tensor1, tensor2)
  566. elif dim_tensor1 == 2 and dim_tensor2 == 1:
  567. return mv(tensor1, tensor2)
  568. elif dim_tensor1 == 1 and dim_tensor2 == 2:
  569. return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0)
  570. elif dim_tensor1 == 2 and dim_tensor2 == 2:
  571. return mm(tensor1, tensor2)
  572. elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
  573. # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
  574. # we track m1 vs m2 separately even though they must match for nicer error messages
  575. n = tensor1[-2] if dim_tensor1 > 1 else 1
  576. batch_tensor1: list[int] = []
  577. # TODO: handling of slice
  578. for i in range(dim_tensor1 - 2):
  579. batch_tensor1.append(tensor1[i])
  580. p = tensor2[-1]
  581. batch_tensor2: list[int] = []
  582. # TODO: handling of slice
  583. for i in range(dim_tensor2 - 2):
  584. batch_tensor2.append(tensor2[i])
  585. # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
  586. expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)
  587. # todo: copy ?
  588. output_shape = expand_batch_portion
  589. if dim_tensor1 > 1:
  590. output_shape.append(n)
  591. if dim_tensor2 > 1:
  592. output_shape.append(p)
  593. return output_shape
  594. else:
  595. raise AssertionError("both arguments to matmul need to be at least 1D")
  596. def t(self: list[int]):
  597. if len(self) > 2:
  598. raise AssertionError(
  599. f"Expected tensor to have <= 2 dimensions, but got {len(self)}"
  600. )
  601. self_len = len(self)
  602. if self_len == 0:
  603. out: list[int] = []
  604. return out
  605. elif self_len == 1:
  606. return [self[0]]
  607. else:
  608. return [self[1], self[0]]
  609. def transpose(self: list[int], dim0: int, dim1: int):
  610. ndims = len(self)
  611. dim0 = maybe_wrap_dim(dim0, ndims)
  612. dim1 = maybe_wrap_dim(dim1, ndims)
  613. if dim0 == dim1:
  614. return _copy(self)
  615. out: list[int] = []
  616. for i in range(ndims):
  617. if i == dim0:
  618. out.append(self[dim1])
  619. elif i == dim1:
  620. out.append(self[dim0])
  621. else:
  622. out.append(self[i])
  623. return out
  624. def linear(input: list[int], weight: list[int], bias: Optional[list[int]]):
  625. out = matmul(input, t(weight))
  626. if bias is not None:
  627. if broadcast(bias, out) != out:
  628. raise AssertionError(
  629. f"Bias shape {bias} is not broadcastable to output shape {out}"
  630. )
  631. return out
  632. def addmm(self: list[int], mat1: list[int], mat2: list[int], beta: Any, alpha: Any):
  633. return broadcast(self, mm(mat1, mat2))
  634. def check_non_negative(array: list[int]) -> bool:
  635. # TODO: look into rewriting with early return and getting loop unrolling to fire
  636. non_negative = False
  637. for val in array:
  638. if val < 0:
  639. non_negative = True
  640. return non_negative
  641. def check_shape_forward(
  642. input: list[int],
  643. weight_sizes: list[int],
  644. bias: Optional[list[int]],
  645. stride: list[int],
  646. padding: list[int],
  647. dilation: list[int],
  648. groups: int,
  649. ):
  650. k = len(input)
  651. weight_dim = len(weight_sizes)
  652. # TODO: assertions could be expanded with the error messages
  653. if check_non_negative(padding):
  654. raise AssertionError(f"Padding must be non-negative, got {padding}")
  655. if check_non_negative(stride):
  656. raise AssertionError(f"Stride must be non-negative, got {stride}")
  657. if weight_dim != k:
  658. raise AssertionError(f"Expected weight_dim ({weight_dim}) == k ({k})")
  659. if weight_sizes[0] < groups:
  660. raise AssertionError(
  661. f"Expected weight_sizes[0] ({weight_sizes[0]}) >= groups ({groups})"
  662. )
  663. if (weight_sizes[0] % groups) != 0:
  664. raise AssertionError(
  665. f"Expected weight_sizes[0] ({weight_sizes[0]}) to be divisible by "
  666. f"groups ({groups})"
  667. )
  668. # only handling not transposed
  669. if input[1] != weight_sizes[1] * groups:
  670. raise AssertionError(
  671. f"Expected input[1] ({input[1]}) == weight_sizes[1] * groups "
  672. f"({weight_sizes[1] * groups})"
  673. )
  674. if bias is not None and not (len(bias) == 1 and bias[0] == weight_sizes[0]):
  675. raise AssertionError(
  676. f"Expected bias to be None or have shape [1] with value "
  677. f"weight_sizes[0]={weight_sizes[0]}, got {bias}"
  678. )
  679. for i in range(2, k):
  680. if (input[i] + 2 * padding[i - 2]) < (
  681. dilation[i - 2] * (weight_sizes[i] - 1) + 1
  682. ):
  683. raise AssertionError(
  684. f"Calculated padded input size ({input[i] + 2 * padding[i - 2]}) "
  685. f"is smaller than effective kernel size "
  686. f"({dilation[i - 2] * (weight_sizes[i] - 1) + 1}) at dimension {i}"
  687. )
  688. # this is not handling transposed convolution yet
  689. def conv_output_size(
  690. input_size: list[int],
  691. weight_size: list[int],
  692. bias: Optional[list[int]],
  693. stride: list[int],
  694. padding: list[int],
  695. dilation: list[int],
  696. groups: int,
  697. ):
  698. check_shape_forward(
  699. input_size, weight_size, bias, stride, padding, dilation, groups
  700. )
  701. has_dilation = len(dilation) > 0
  702. dim = len(input_size)
  703. output_size: list[int] = []
  704. input_batch_size_dim = 0
  705. weight_output_channels_dim = 0
  706. output_size.append(input_size[input_batch_size_dim])
  707. output_size.append(weight_size[weight_output_channels_dim])
  708. for d in range(2, dim):
  709. dilation_ = dilation[d - 2] if has_dilation else 1
  710. kernel = dilation_ * (weight_size[d] - 1) + 1
  711. output_size.append(
  712. (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
  713. )
  714. return output_size
  715. def conv1d(
  716. input: list[int],
  717. weight: list[int],
  718. bias: Optional[list[int]],
  719. stride: list[int],
  720. padding: list[int],
  721. dilation: list[int],
  722. groups: int,
  723. ):
  724. if len(weight) != 3:
  725. raise AssertionError(f"Expected 3D weight for conv1d, got {len(weight)}D")
  726. if len(input) != 3:
  727. raise AssertionError(f"Expected 3D input for conv1d, got {len(input)}D")
  728. return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
  729. def conv2d(
  730. input: list[int],
  731. weight: list[int],
  732. bias: Optional[list[int]],
  733. stride: list[int],
  734. padding: list[int],
  735. dilation: list[int],
  736. groups: int,
  737. ):
  738. if len(weight) != 4:
  739. raise AssertionError(f"Expected 4D weight for conv2d, got {len(weight)}D")
  740. if len(input) != 4:
  741. raise AssertionError(f"Expected 4D input for conv2d, got {len(input)}D")
  742. return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
  743. def conv_backwards(
  744. grad_output: list[int],
  745. input: list[int],
  746. weight: list[int],
  747. biases: Optional[list[int]],
  748. ):
  749. # Bias gradient is always generated regardess of if biases is supplied
  750. return _copy(input), _copy(weight), [grad_output[1]]
  751. def conv_transpose2d_input(
  752. input: list[int],
  753. weight: list[int],
  754. bias: Optional[list[int]] = None,
  755. stride: Optional[list[int]] = None,
  756. padding: Optional[list[int]] = None,
  757. output_padding: Optional[list[int]] = None,
  758. groups: int = 1,
  759. dilation: Optional[list[int]] = None,
  760. ) -> list[int]:
  761. if stride is None:
  762. stride = [1, 1]
  763. if padding is None:
  764. padding = [0, 0]
  765. if output_padding is None:
  766. output_padding = [0, 0]
  767. if dilation is None:
  768. dilation = [1, 1]
  769. has_dilation = len(dilation) > 0
  770. dim = len(input)
  771. output_size: list[int] = []
  772. input_batch_size_dim = 0
  773. weight_output_channels_dim = 1
  774. output_size.append(input[input_batch_size_dim])
  775. output_size.append(weight[weight_output_channels_dim] * groups)
  776. for d in range(2, dim):
  777. dilation_ = dilation[d - 2] if has_dilation else 1
  778. kernel = dilation_ * (weight[d] - 1)
  779. output_size.append(
  780. (input[d] - 1) * stride[d - 2]
  781. - 2 * padding[d - 2]
  782. + kernel
  783. + output_padding[d - 2]
  784. + 1
  785. )
  786. return output_size
  787. def conv_forwards(
  788. input: list[int],
  789. weight: list[int],
  790. bias: Optional[list[int]],
  791. stride: list[int],
  792. padding: list[int],
  793. dilation: list[int],
  794. transposed: bool,
  795. output_padding: list[int],
  796. groups: int,
  797. ) -> list[int]:
  798. has_dilation = len(dilation) > 0
  799. has_output_padding = len(output_padding) > 0
  800. dim = len(input)
  801. output_size: list[int] = []
  802. input_batch_size_dim = 0
  803. weight_output_channels_dim = 1 if transposed else 0
  804. output_size.append(input[input_batch_size_dim])
  805. if transposed:
  806. output_size.append(weight[weight_output_channels_dim] * groups)
  807. else:
  808. output_size.append(weight[weight_output_channels_dim])
  809. for d in range(2, dim):
  810. dilation_ = dilation[d - 2] if has_dilation else 1
  811. output_padding_ = output_padding[d - 2] if has_output_padding else 0
  812. if transposed:
  813. kernel = dilation_ * (weight[d] - 1)
  814. output_size.append(
  815. (input[d] - 1) * stride[d - 2]
  816. - 2 * padding[d - 2]
  817. + kernel
  818. + output_padding_
  819. + 1
  820. )
  821. else:
  822. kernel = dilation_ * (weight[d] - 1) + 1
  823. output_size.append(
  824. (input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
  825. )
  826. return output_size
  827. def _conv_forwards(
  828. input: list[int],
  829. weight: list[int],
  830. bias: Optional[list[int]],
  831. stride: list[int],
  832. padding: list[int],
  833. dilation: list[int],
  834. transposed: bool,
  835. output_padding: list[int],
  836. groups: int,
  837. benchmark: bool,
  838. deterministic: bool,
  839. cudnn_enabled: bool,
  840. allow_tf32: bool,
  841. ) -> list[int]:
  842. return conv_forwards(
  843. input,
  844. weight,
  845. bias,
  846. stride,
  847. padding,
  848. dilation,
  849. transposed,
  850. output_padding,
  851. groups,
  852. )
  853. def batch_norm(
  854. input: list[int],
  855. weight: Optional[list[int]],
  856. bias: Optional[list[int]],
  857. running_mean: Optional[list[int]],
  858. running_var: Optional[list[int]],
  859. training: bool,
  860. momentum: float,
  861. eps: float,
  862. cudnn_enabled: bool,
  863. ):
  864. out: list[int] = []
  865. for elem in input:
  866. out.append(elem)
  867. return out
  868. def conv3d(
  869. input: list[int],
  870. weight: list[int],
  871. bias: Optional[list[int]],
  872. stride: list[int],
  873. padding: list[int],
  874. dilation: list[int],
  875. groups: int,
  876. ):
  877. if len(weight) != 5:
  878. raise AssertionError(f"Expected 5D weight for conv3d, got {len(weight)}D")
  879. if len(input) != 5:
  880. raise AssertionError(f"Expected 5D input for conv3d, got {len(input)}D")
  881. return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
  882. def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
  883. if dim_post_expr <= 0:
  884. if not wrap_scalar:
  885. raise AssertionError(
  886. "Expected wrap_scalar to be True when dim_post_expr <= 0"
  887. )
  888. dim_post_expr = 1
  889. min = -dim_post_expr
  890. max = dim_post_expr - 1
  891. if dim < min or dim > max:
  892. raise AssertionError(
  893. f"Dimension {dim} out of range (expected to be in range [{min}, {max}])"
  894. )
  895. if dim < 0:
  896. dim += dim_post_expr
  897. return dim
  898. def zero_dim_tensor(input: Any):
  899. out: list[int] = []
  900. return out
  901. def multiply_integers(li: list[int]):
  902. out = 1
  903. for elem in li:
  904. out = out * elem
  905. return out
  906. def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
  907. if end < 0:
  908. raise AssertionError(f"Expected end ({end}) >= 0")
  909. return [int(math.ceil(end))]
  910. def arange_start(
  911. start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
  912. ):
  913. if end < 0:
  914. raise AssertionError(f"Expected end ({end}) >= 0")
  915. if end < start:
  916. raise AssertionError(f"Expected end ({end}) >= start ({start})")
  917. return [int(math.ceil(end - start))]
  918. def arange_start_step(
  919. start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
  920. ):
  921. if step == 0:
  922. raise AssertionError("step must not be zero")
  923. if step < 0:
  924. if start < end:
  925. raise AssertionError(
  926. f"Expected start ({start}) >= end ({end}) when step < 0"
  927. )
  928. else:
  929. if end < start:
  930. raise AssertionError(
  931. f"Expected end ({end}) >= start ({start}) when step > 0"
  932. )
  933. return [int(math.ceil((end - start) / step))]
  934. def permute(input: list[int], dims: list[int]):
  935. if len(input) != len(dims):
  936. raise AssertionError(
  937. f"Expected len(input) ({len(input)}) == len(dims) ({len(dims)})"
  938. )
  939. ndim = len(dims)
  940. seen_dims: list[int] = []
  941. newSizes: list[int] = []
  942. for i in range(ndim):
  943. dim = maybe_wrap_dim(dims[i], ndim)
  944. seen_dims.append(dim)
  945. newSizes.append(input[dim])
  946. for i in range(1, ndim):
  947. for j in range(i):
  948. if seen_dims[i] == seen_dims[j]:
  949. raise AssertionError(
  950. f"Repeated dimension {seen_dims[i]} in permute dimensions"
  951. )
  952. return newSizes
  953. def movedim(self: list[int], source: list[int], destination: list[int]) -> list[int]:
  954. self_dim = len(self)
  955. if self_dim <= 1:
  956. return self
  957. normalized_src: list[int] = []
  958. normalized_dst: list[int] = []
  959. for i in range(len(source)):
  960. normalized_src.append(maybe_wrap_dim(source[i], self_dim))
  961. normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
  962. order = [-1 for i in range(self_dim)]
  963. src_dims = [i for i in range(self_dim)]
  964. dst_dims = [i for i in range(self_dim)]
  965. for i in range(len(source)):
  966. order[normalized_dst[i]] = normalized_src[i]
  967. src_dims[normalized_src[i]] = -1
  968. dst_dims[normalized_dst[i]] = -1
  969. source_dims: list[int] = []
  970. destination_dims: list[int] = []
  971. for ele in src_dims:
  972. if ele != -1:
  973. source_dims.append(ele)
  974. for ele in dst_dims:
  975. if ele != -1:
  976. destination_dims.append(ele)
  977. rest_dim = self_dim - len(source)
  978. for i in range(rest_dim):
  979. order[destination_dims[i]] = source_dims[i]
  980. return permute(self, order)
  981. def flatten(input: list[int], start_dim: int, end_dim: int):
  982. start_dim = maybe_wrap_dim(start_dim, len(input))
  983. end_dim = maybe_wrap_dim(end_dim, len(input))
  984. if start_dim > end_dim:
  985. raise AssertionError(f"Expected start_dim ({start_dim}) <= end_dim ({end_dim})")
  986. if len(input) == 0:
  987. return [1]
  988. if start_dim == end_dim:
  989. # TODO: return self
  990. out: list[int] = []
  991. for elem in input:
  992. out.append(elem)
  993. return out
  994. slice_numel = 1
  995. for i in range(start_dim, end_dim + 1):
  996. slice_numel *= input[i]
  997. # TODO: use slicing when slice optimization has landed
  998. # slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1])
  999. shape: list[int] = []
  1000. for i in range(start_dim):
  1001. shape.append(input[i])
  1002. shape.append(slice_numel)
  1003. for i in range(end_dim + 1, len(input)):
  1004. shape.append(input[i])
  1005. return shape
  1006. def nonzero_lower_bound(input: list[int]):
  1007. return [0, len(input)]
  1008. def nonzero_upper_bound(input: list[int]):
  1009. return [numel(input), len(input)]
  1010. def _reduce_along_dim(self: list[int], dim: int, keepdim: bool):
  1011. dim = maybe_wrap_dim(dim, len(self))
  1012. out: list[int] = []
  1013. for i, self_dim in enumerate(self):
  1014. if i == dim:
  1015. if keepdim:
  1016. out.append(1)
  1017. else:
  1018. out.append(self_dim)
  1019. return out
  1020. def argmax(
  1021. self: list[int], dim: Optional[int] = None, keepdim: bool = False
  1022. ) -> list[int]:
  1023. if dim is None:
  1024. return []
  1025. return _reduce_along_dim(self, dim, keepdim)
  1026. def bmm(self: list[int], mat2: list[int]) -> list[int]:
  1027. if len(self) != 3:
  1028. raise AssertionError(f"bmm only supports 3D tensors, got {len(self)}D")
  1029. if len(mat2) != 3:
  1030. raise AssertionError(f"bmm only supports 3D tensors, got {len(mat2)}D")
  1031. if self[0] != mat2[0]:
  1032. raise AssertionError(
  1033. f"mismatching batch dimension: self[0]={self[0]}, mat2[0]={mat2[0]}"
  1034. )
  1035. if self[2] != mat2[1]:
  1036. raise AssertionError(
  1037. f"mismatching contracting dimension: self[2]={self[2]}, mat2[1]={mat2[1]}"
  1038. )
  1039. return [self[0], self[1], mat2[2]]
  1040. def _shape_as_tensor(self: list[int]) -> list[int]:
  1041. return [len(self)]
  1042. def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]:
  1043. if len(self) == 0:
  1044. result: list[int] = []
  1045. else:
  1046. if k > self[dim]:
  1047. raise AssertionError(
  1048. f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
  1049. )
  1050. result = _copy(self)
  1051. result[dim] = k
  1052. return result, result
  1053. def nll_loss_forward(
  1054. self: list[int], target: list[int], weight: Optional[list[int]], reduction: int
  1055. ) -> tuple[list[int], list[int]]:
  1056. # This is taken shamelessly from the meta function in LossNLL.cpp
  1057. self_dim = len(self)
  1058. target_dim = len(target)
  1059. if not (0 < self_dim <= 2):
  1060. raise AssertionError(f"Expected 0 < self_dim <= 2, but got self_dim={self_dim}")
  1061. if target_dim > 1:
  1062. raise AssertionError(f"Expected target_dim <= 1, but got {target_dim}")
  1063. no_batch_dim = self_dim == 1 and target_dim == 0
  1064. if not (no_batch_dim or (self[0] == target[0])):
  1065. raise AssertionError(
  1066. f"Batch size mismatch: self[0]={self[0]}, target[0]={target[0]}"
  1067. )
  1068. n_classes = self[-1]
  1069. scalar_shape: list[int] = []
  1070. if weight is not None and not (len(weight) == 1 and weight[0] == n_classes):
  1071. raise AssertionError(
  1072. f"Expected weight to be None or have shape [n_classes], "
  1073. f"got {weight} with n_classes={n_classes}"
  1074. )
  1075. if reduction == 0 and self_dim == 2:
  1076. reduction_shape = [self[0]]
  1077. else:
  1078. reduction_shape = scalar_shape
  1079. return reduction_shape, scalar_shape
  1080. def native_layer_norm(
  1081. input: list[int], normalized_shape: list[int]
  1082. ) -> tuple[list[int], list[int], list[int]]:
  1083. reduction_shape: list[int] = []
  1084. num_unreduced_dimensions = len(input) - len(normalized_shape)
  1085. if num_unreduced_dimensions < 0:
  1086. raise AssertionError(
  1087. f"Expected len(input) ({len(input)}) >= len(normalized_shape) "
  1088. f"({len(normalized_shape)})"
  1089. )
  1090. for i in range(num_unreduced_dimensions):
  1091. reduction_shape.append(input[i])
  1092. for i in range(num_unreduced_dimensions, len(input)):
  1093. reduction_shape.append(1)
  1094. return _copy(input), reduction_shape, reduction_shape
  1095. def native_batch_norm(
  1096. input: list[int],
  1097. weight: Optional[list[int]],
  1098. bias: Optional[list[int]],
  1099. running_mean: Optional[list[int]],
  1100. running_var: Optional[list[int]],
  1101. training: bool,
  1102. ) -> tuple[list[int], list[int], list[int]]:
  1103. if training:
  1104. _size = [input[1]]
  1105. else:
  1106. _size = [0]
  1107. return _copy(input), _size, _size
  1108. def _batch_norm_with_update(
  1109. input: list[int],
  1110. weight: Optional[list[int]],
  1111. bias: Optional[list[int]],
  1112. running_mean: Optional[list[int]],
  1113. running_var: Optional[list[int]],
  1114. ) -> tuple[list[int], list[int], list[int], list[int]]:
  1115. _size = [input[1]]
  1116. return _copy(input), _size, _size, [0]
  1117. def cross_entropy_loss(
  1118. self: list[int],
  1119. target: list[int],
  1120. weight: Optional[list[int]] = None,
  1121. reduction: int = 1,
  1122. ignore_index: int = -100,
  1123. label_smoothing: float = 0.0,
  1124. ) -> list[int]:
  1125. result_shape = nll_loss_forward(self, target, weight, reduction)[0]
  1126. return result_shape
  1127. """
Currently deferring the enabling of this, as part of the proposal to suspend
adding ops.
There are currently cases in the test suite where this is being called
in the SSA opinfo tests with unexpected values (e.g. a list of two ints; see the first
opinfo test). The behavior of index is significantly dependent on the inputs.
This could be an error in how we are matching up shape functions, or this
function may simply need to implement every case.
  1135. def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
  1136. assert len(indices) <= len(self), "More indices than dimensions to index"
  1137. broadcasted_shape: List[int] = []
  1138. for index_tensor_shape in indices:
  1139. if index_tensor_shape is not None:
  1140. broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape)
  1141. return broadcasted_shape
  1142. """
  1143. ScriptFn = torch._C.ScriptFunction
  1144. shape_compute_graph_mapping: dict[str, ScriptFn] = {}
  1145. bounded_compute_graph_mapping: dict[str, tuple[ScriptFn, ScriptFn]] = {}
  1146. script_func_map: dict[Callable, ScriptFn] = {}
  1147. def process_func(func: Callable):
  1148. if func not in script_func_map:
  1149. scripted_func = torch.jit.script(func)
  1150. torch._C._jit_pass_inline(scripted_func.graph)
  1151. for _ in range(2):
  1152. torch._C._jit_pass_peephole(scripted_func.graph)
  1153. torch._C._jit_pass_constant_propagation(scripted_func.graph)
  1154. script_func_map[func] = scripted_func
  1155. return script_func_map[func]
  1156. def add_shape_compute_mapping(operator_schema: str, func: Callable):
  1157. global shape_compute_graph_mapping
  1158. shape_compute_graph_mapping[operator_schema] = process_func(func)
  1159. def add_bounded_compute_mapping(
  1160. operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable
  1161. ):
  1162. # Adds a shape compute function for both upper and lower bounds
  1163. fns = (process_func(lower_bound_func), process_func(upper_bound_func))
  1164. bounded_compute_graph_mapping[operator_schema] = fns
# ---------------------------------------------------------------------------
# Shape compute function registrations.
#
# Each add_shape_compute_mapping(schema, fn) call scripts `fn` through
# process_func, which caches per Python function. It then stores the result in
# shape_compute_graph_mapping keyed by the literal schema string. Several
# schemas may share one shape function (e.g. `unary`, `zero_dim_tensor`).
# NOTE(review): keys are presumably matched against registered operator
# schemas by consumers, so keep each string exact; confirm before editing.
# ---------------------------------------------------------------------------
add_shape_compute_mapping(
    "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
    unary,
)
add_shape_compute_mapping(
    "aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::dropout(Tensor input, float p, bool train) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor",
    adaptive_avg_pool2d,
)
add_shape_compute_mapping(
    "prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping(
    "aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    unary,
)
add_shape_compute_mapping(
    "aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
    unary,
)
add_shape_compute_mapping(
    "aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    arange_end,
)
add_shape_compute_mapping(
    "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    arange_start,
)
add_shape_compute_mapping(
    "aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    arange_start_step,
)
add_shape_compute_mapping("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim)
add_shape_compute_mapping(
    "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze
)
add_shape_compute_mapping(
    "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims
)
add_shape_compute_mapping(
    "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze
)
add_shape_compute_mapping(
    "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
    slice,
)
add_shape_compute_mapping(
    "aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select
)
add_shape_compute_mapping(
    "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select
)
# The layer_norm schema is split over two adjacent string literals, which
# Python concatenates into one key.
add_shape_compute_mapping(
    "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
    "float eps=1e-05, bool cudnn_enable=True) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)",
    unary,
)
add_shape_compute_mapping(
    "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor",
    embedding,
)
add_shape_compute_mapping("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm)
add_shape_compute_mapping("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot)
add_shape_compute_mapping("aten::mv(Tensor self, Tensor vec) -> Tensor", mv)
add_shape_compute_mapping("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul)
add_shape_compute_mapping(
    "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear
)
add_shape_compute_mapping(
    "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
    max_pool2d,
)
add_shape_compute_mapping(
    "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)",
    max_pool2d_with_indices,
)
add_shape_compute_mapping("aten::t(Tensor(a) self) -> Tensor(a)", t)
add_shape_compute_mapping(
    "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose
)
# Convolution family.
add_shape_compute_mapping(
    "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor",
    conv1d,
)
add_shape_compute_mapping(
    "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
    conv2d,
)
add_shape_compute_mapping(
    "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
    batch_norm,
)
add_shape_compute_mapping(
    "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor",
    conv3d,
)
add_shape_compute_mapping(
    "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)",
    conv_backwards,
)
add_shape_compute_mapping(
    "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
    conv_forwards,
)
add_shape_compute_mapping(
    "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
    _conv_forwards,
)
add_shape_compute_mapping(
    "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor",
    conv_transpose2d_input,
)
add_shape_compute_mapping(
    "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)",
    flatten,
)
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack)
add_shape_compute_mapping(
    "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute
)
add_shape_compute_mapping(
    "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)",
    movedim,
)
add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)
add_shape_compute_mapping(
    "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand
)
add_shape_compute_mapping(
    "aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)",
    expand_one_unused,
)
# Reductions.
add_shape_compute_mapping(
    "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    sum_mean_dim,
)
add_shape_compute_mapping(
    "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    sum_mean_dim,
)
add_shape_compute_mapping(
    "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
    max_dim,
)
# Full reductions produce a 0-d result.
add_shape_compute_mapping(
    "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping(
    "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping(
    "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
    addmm,
)
add_shape_compute_mapping(
    "aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)",
    upsample_nearest2d,
)
# Quantization ops.
add_shape_compute_mapping(
    "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor",
    unary,
)
add_shape_compute_mapping("aten::dequantize(Tensor self) -> Tensor", unary)
add_shape_compute_mapping(
    "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
    broadcast,
)
add_shape_compute_mapping(
    "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax
)
add_shape_compute_mapping("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm)
add_shape_compute_mapping(
    "aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor
)
add_shape_compute_mapping(
    "aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
    topk,
)
# Losses and normalization.
add_shape_compute_mapping(
    "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)",
    nll_loss_forward,
)
add_shape_compute_mapping(
    "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
    native_layer_norm,
)
add_shape_compute_mapping(
    "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
add_shape_compute_mapping(
    "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
add_shape_compute_mapping(
    "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
# NOTE(review): unlike every other key in this list, this schema has no
# "aten::" namespace prefix. If consumers resolve keys as qualified operator
# names, this entry may never match; confirm before relying on it.
add_shape_compute_mapping(
    "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
    _batch_norm_with_update,
)
add_shape_compute_mapping(
    "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
    cross_entropy_loss,
)
# Disabled; see the string literal above index_Tensor for why.
# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)

# TODO: migrate over all of symbolic_shape_registry_util.cpp
# These are duplicated here so that the functions will be serialized
add_shape_compute_mapping(
    "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor",
    broadcast_three,
)
add_shape_compute_mapping(
    "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
    broadcast_one_three,
)
add_shape_compute_mapping(
    "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
    broadcast_inplace,
)

# quantized_conv_prepack TODO

# Shape Compute Fn with upper and lower bounds: these ops register a
# (lower bound, upper bound) pair via add_bounded_compute_mapping instead of a
# single function.
add_bounded_compute_mapping(
    "aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound
)