| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633 |
- # mypy: allow-untyped-defs
- import math
- from collections.abc import Callable
- from typing import Any, Optional, Union
- number = Union[int, float]
- # flake8: noqa
- ###
- # There are generated files that depend on this file
- # To re-generate, please run from the root of the repo:
- # python torchgen/shape_functions/gen_jit_shape_functions.py
- # How to test:
- # After regenerating files, compile PyTorch.
- # Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
- # If you have enabled opinfo testing for the op, also run:
- # python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
- # to reproduce errors from opinfo tests.
- # Example PR: https://github.com/pytorch/pytorch/pull/80860/files
- ####
- import torch
def broadcast(a: list[int], b: list[int]):
    """Return the shape obtained by broadcasting shapes ``a`` and ``b``.

    Shapes are aligned from their trailing dimension; a dimension missing from
    the shorter shape is treated as size 1. Two aligned sizes are compatible
    when they are equal or when either of them is 1; the result takes the
    non-1 size.

    Raises:
        AssertionError: if two aligned sizes are incompatible. ``i`` in the
            message is the index in the broadcast (output) shape.
    """
    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes: list[int] = []
    for i in range(ndim):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if (dimA >= 0) else 1
        sizeB = b[dimB] if (dimB >= 0) else 1
        if sizeA != sizeB and sizeA != 1 and sizeB != 1:
            # TODO: only assertion error is bound in C++ compilation right now
            raise AssertionError(
                f"The size of tensor a ({sizeA}) must match the size of tensor b ({sizeB}) at non-singleton dimension {i}"
            )
        expandedSizes.append(sizeB if sizeA == 1 else sizeA)
    return expandedSizes
def broadcast_three(a: list[int], b: list[int], c: list[int]):
    """Broadcast ``a`` with ``b``, then broadcast that result with ``c``."""
    ab = broadcast(a, b)
    return broadcast(ab, c)
def broadcast_one_three(a: list[int], b: Any, c: list[int]):
    """Broadcast shapes ``a`` and ``c``; the middle argument ``b`` is ignored."""
    result = broadcast(a, c)
    return result
def adaptive_avg_pool2d(self: list[int], out: list[int]):
    """Shape of adaptive average pooling over the last two dims of ``self``.

    ``self`` must have 3 or 4 dims, every dim after the first must be non-zero,
    and ``out`` must hold exactly two sizes. The leading ``len(self) - 2`` dims
    are kept and the trailing two are replaced by ``out``.
    """
    if len(out) != 2:
        raise AssertionError(f"Expected out to have length 2, but got {len(out)}")
    rank = len(self)
    if rank != 3 and rank != 4:
        raise AssertionError(
            f"Expected self to have length 3 or 4, but got {len(self)}"
        )
    for d in range(1, rank):
        if self[d] == 0:
            raise AssertionError(f"Expected self[{d}] to be non-zero, but got 0")
    result: list[int] = []
    for d in range(rank - 2):
        result.append(self[d])
    for target in out:
        result.append(target)
    return result
- def _copy(self: list[int]):
- out: list[int] = []
- for elem in self:
- out.append(elem)
- return out
def unary(self: list[int]):
    """Shape function for elementwise unary ops: a copy of the input shape."""
    result = _copy(self)
    return result
def broadcast_inplace(a: list[int], b: list[int]):
    """Shape of an in-place broadcasting op that writes into ``a``.

    ``b`` must broadcast to ``a`` without changing it: ``b`` may not have more
    dims than ``a``, and each of its trailing-aligned sizes must equal the
    corresponding size of ``a`` or be 1. Returns a copy of ``a``.

    Raises:
        AssertionError: if ``b`` has more dims than ``a`` or a size mismatches.
    """
    dimsA = len(a)
    dimsB = len(b)
    if dimsB > dimsA:
        raise AssertionError(
            f"The dims of tensor b ({dimsB}) must be less than or equal to the dims of tensor a ({dimsA})"
        )
    for dimA in range(dimsA):
        # Align b to the trailing dims of a; missing leading dims act as size 1.
        dimB = dimsB - dimsA + dimA
        sizeA = a[dimA]
        sizeB = b[dimB] if (dimB >= 0) else 1
        # Unlike broadcast(), only b may be expanded: a's size is fixed.
        if sizeA != sizeB and sizeB != 1:
            # TODO: only assertion error is bound in C++ compilation right now
            raise AssertionError(
                f"The size of tensor a ({sizeA}) must match the size of tensor b ({sizeB}) at non-singleton dimension {dimA}"
            )
    return _copy(a)
def expand(self: list[int], sizes: list[int]):
    """Shape of ``self.expand(sizes)``.

    ``sizes`` is aligned to the trailing dims of ``self`` and may add leading
    dims. An entry of ``-1`` keeps the existing size (only allowed for dims
    that exist in ``self``). Any other entry must equal the existing size, or
    the existing size must be 1.

    Raises:
        AssertionError: if ``sizes`` is shorter than ``self``, ``-1`` is used
            for a new leading dim, or a non-1 size would have to change.
    """
    target_rank = len(sizes)
    src_rank = len(self)
    if target_rank < src_rank:
        raise AssertionError(
            f"Expected len(sizes) ({len(sizes)}) >= len(self) ({len(self)})"
        )
    if target_rank == 0:
        return _copy(sizes)
    result: list[int] = []
    for i in range(target_rank):
        # Index into self for output position i; negative means a new leading dim.
        src_dim = src_rank - target_rank + i
        current = self[src_dim] if src_dim >= 0 else 1
        requested = sizes[i]
        if requested == -1:
            if src_dim < 0:
                raise AssertionError(f"Expected dim ({src_dim}) >= 0 when targetSize is -1")
            requested = current
        if current != requested:
            if current != 1:
                raise AssertionError(
                    f"Expected size ({current}) == 1 when size != targetSize ({requested})"
                )
            current = requested
        result.append(current)
    return result
def expand_one_unused(self: list[int], sizes: list[int], inp0: Any):
    """Same as :func:`expand`; the extra ``inp0`` argument does not affect the shape."""
    result = expand(self, sizes)
    return result
def infer_size_impl(shape: list[int], numel: int) -> list[int]:
    """Resolve a view/reshape target ``shape`` against ``numel`` elements.

    At most one entry of ``shape`` may be ``-1``. It is replaced by the size
    that makes the element count match ``numel``. Returns a new list.

    Raises:
        AssertionError: if more than one entry is ``-1``, an entry is negative
            (other than ``-1``), the element counts cannot be made to agree,
            or the ``-1`` entry is ambiguous because the remaining entries
            multiply to 0.
    """
    newsize = 1
    infer_dim: Optional[int] = None
    for dim in range(len(shape)):
        if shape[dim] == -1:
            if infer_dim is not None:
                raise AssertionError("only one dimension can be inferred")
            infer_dim = dim
        elif shape[dim] >= 0:
            newsize *= shape[dim]
        else:
            raise AssertionError("invalid shape dimensions")
    if not (
        numel == newsize
        or (infer_dim is not None and newsize > 0 and numel % newsize == 0)
    ):
        raise AssertionError("invalid shape")
    if infer_dim is not None and newsize == 0:
        # Having passed the check above with newsize == 0 means numel == 0,
        # so any value satisfies the -1 entry. Reject it instead of dividing
        # by zero below.
        raise AssertionError(
            "cannot reshape tensor of 0 elements into shape "
            f"{shape} because the unspecified dimension size -1 can be any "
            "value and is ambiguous"
        )
    out = _copy(shape)
    if infer_dim is not None:
        out[infer_dim] = numel // newsize
    return out
def numel(sizes: list[int]):
    """Return the element count for shape ``sizes`` (1 for an empty shape)."""
    total = 1
    for extent in sizes:
        total = total * extent
    return total
def view(self: list[int], sizes: list[int]):
    """Shape of ``self.view(sizes)``; ``sizes`` may contain a single ``-1``."""
    element_count = numel(self)
    return infer_size_impl(sizes, element_count)
def view_one_unused(self: list[int], sizes: list[int], *, implicit: bool = False):
    """Same as :func:`view`; the keyword-only ``implicit`` flag does not affect the shape."""
    result = view(self, sizes)
    return result
def sum_mean_dim(
    self: list[int], opt_dims: Optional[list[int]], keep_dim: bool, dt: Any
):
    """Shape after reducing ``self`` over ``opt_dims`` (sum/mean style).

    ``None`` or an empty list means "reduce over every dim". Reduced dims are
    dropped, or kept as size 1 when ``keep_dim`` is True. Reduction dims may
    be negative. ``dt`` does not affect the shape.
    """
    rank = len(self)
    reduce_dims: list[int] = list(range(rank))
    if opt_dims is not None and len(opt_dims) > 0:
        reduce_dims = opt_dims
    result: list[int] = []
    for pos in range(rank):
        reduced = False
        for d in reduce_dims:
            if pos == maybe_wrap_dim(d, rank):
                reduced = True
        if not reduced:
            result.append(self[pos])
        elif keep_dim:
            result.append(1)
    return result
def max_dim(self: list[int], dim: int, keep_dim: bool):
    """Shapes of ``(values, indices)`` for a max over ``dim``; both are the same list."""
    reduced = sum_mean_dim(self, [dim], keep_dim, None)
    return reduced, reduced
def div_rtn(x: int, y: int):
    """Integer division rounding toward negative infinity.

    Python's ``//`` already floors for negative operands, so no correction
    term is needed.
    """
    quotient = x // y
    return quotient
def pooling_output_shape_pad_lr(
    inputSize: int,
    kernelSize: int,
    pad_l: int,
    pad_r: int,
    stride: int,
    dilation: int,
    ceil_mode: bool,
):
    """Number of pooling windows along one dim, with separate left/right padding.

    Computes ``floor((inputSize + pad_l + pad_r - (dilation * (kernelSize - 1) + 1)) / stride) + 1``.
    In ``ceil_mode`` the division rounds up instead. The count is then reduced
    by one if the last window would start at or beyond ``inputSize + pad_l``.

    ``stride`` is used as a divisor and is not validated here.
    """
    outputSize = (
        div_rtn(
            inputSize
            + pad_l
            + pad_r
            # These two terms subtract the dilated kernel extent
            # dilation * (kernelSize - 1) + 1.
            - dilation * (kernelSize - 1)
            - 1
            # Adding stride - 1 before the floor division rounds the quotient up.
            + (stride - 1 if ceil_mode else 0),
            stride,
        )
        + 1
    )
    if ceil_mode:
        # Drop the last window if it would start past the input plus left
        # padding, i.e. it would cover only right padding.
        if (outputSize - 1) * stride >= inputSize + pad_l:
            outputSize = outputSize - 1
    return outputSize
def pooling_output_shape(
    inputSize: int,
    kernelSize: int,
    pad_l: int,
    stride: int,
    dilation: int,
    ceil_mode: bool,
):
    """Number of pooling windows along one dim with the same padding on both sides.

    Raises:
        AssertionError: if ``stride`` is zero.
    """
    if stride == 0:
        raise AssertionError("stride should not be zero")
    return pooling_output_shape_pad_lr(
        inputSize=inputSize,
        kernelSize=kernelSize,
        pad_l=pad_l,
        pad_r=pad_l,
        stride=stride,
        dilation=dilation,
        ceil_mode=ceil_mode,
    )
def pool2d_shape_check(
    input: list[int],
    kH: int,
    kW: int,
    dH: int,
    dW: int,
    padH: int,
    padW: int,
    dilationH: int,
    dilationW: int,
    nInputPlane: int,
    inputHeight: int,
    inputWidth: int,
    outputHeight: int,
    outputWidth: int,
):
    """Validate 2-D pooling parameters and the computed output size.

    Returns nothing. Raises AssertionError on the first failed check:

    * kernel sizes, strides and dilations must all be positive;
    * ``input`` must be 3-D with every dim non-zero, or 4-D with dims 1-3
      non-zero (dim 0 of a 4-D input is not checked);
    * padding may be at most half the kernel size (``k // 2 >= pad``);
    * the computed output height and width must both be at least 1.

    ``input`` is indexed up to ``input[2]`` before its rank is checked, so it
    is assumed to have at least 3 dims. ``nInputPlane``, ``inputHeight`` and
    ``inputWidth`` are accepted but not read here.
    """
    ndim = len(input)
    if not (kW > 0 and kH > 0):
        raise AssertionError(f"Expected kW ({kW}) > 0 and kH ({kH}) > 0")
    if not (dW > 0 and dH > 0):
        raise AssertionError(f"Expected dW ({dW}) > 0 and dH ({dH}) > 0")
    if not (dilationH > 0 and dilationW > 0):
        raise AssertionError(
            f"Expected dilationH ({dilationH}) > 0 and dilationW ({dilationW}) > 0"
        )
    # Dims 1 and 2 must be non-zero for both accepted ranks.
    valid_dims = input[1] != 0 and input[2] != 0
    # `and` binds tighter than `or`: accept either (3-D, dim 0 non-zero) or
    # (4-D, dim 3 non-zero), each additionally requiring valid_dims.
    if not (
        ndim == 3
        and input[0] != 0
        and valid_dims
        or (ndim == 4 and valid_dims and input[3] != 0)
    ):
        raise AssertionError(f"Invalid input dimensions: ndim={ndim}, input={input}")
    if not (kW // 2 >= padW and kH // 2 >= padH):
        raise AssertionError(
            f"Expected kW//2 ({kW // 2}) >= padW ({padW}) and "
            f"kH//2 ({kH // 2}) >= padH ({padH})"
        )
    if not (outputWidth >= 1 and outputHeight >= 1):
        raise AssertionError(
            f"Expected outputWidth ({outputWidth}) >= 1 and "
            f"outputHeight ({outputHeight}) >= 1"
        )
def max_pool2d(
    input: list[int],
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
):
    """Output shape of 2-D max pooling.

    ``kernel_size``, ``padding`` and ``dilation`` take one value (used for both
    of the last two dims) or two values ``(H, W)``. ``stride`` may also be
    empty, in which case it defaults to the kernel size. ``input`` must be 3-D
    or 4-D; its last two entries are the height and width being pooled.

    Returns ``[C, outH, outW]`` for 3-D input and ``[N, C, outH, outW]`` for
    4-D input, where ``C`` and ``N`` are copied from ``input``.

    Raises:
        AssertionError: on malformed argument lengths, or if
            ``pool2d_shape_check`` rejects the configuration.
    """
    if not (len(kernel_size) == 1 or len(kernel_size) == 2):
        raise AssertionError(
            "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
        )
    kH = kernel_size[0]
    # A single kernel value applies to both dims.
    kW = kH if len(kernel_size) == 1 else kernel_size[1]
    if not (len(stride) == 0 or len(stride) == 1 or len(stride) == 2):
        raise AssertionError(
            "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
        )
    # An omitted stride defaults to the kernel size; a single value applies to both dims.
    dH = kH if len(stride) == 0 else stride[0]
    if len(stride) == 0:
        dW = kW
    elif len(stride) == 1:
        dW = dH
    else:
        dW = stride[1]
    if not (len(padding) == 1 or len(padding) == 2):
        raise AssertionError(
            "max_pool2d: padding must either be a single int, or a tuple of two ints"
        )
    padH = padding[0]
    padW = padH if len(padding) == 1 else padding[1]
    if not (len(dilation) == 1 or len(dilation) == 2):
        raise AssertionError(
            "max_pool2d: dilation must be either a single int, or a tuple of two ints"
        )
    dilationH = dilation[0]
    dilationW = dilationH if len(dilation) == 1 else dilation[1]
    if not (len(input) == 3 or len(input) == 4):
        raise AssertionError(f"Expected input length 3 or 4, but got {len(input)}")
    # nbatch is only used in the 4-D result; the placeholder 1 is never returned.
    nbatch = input[-4] if len(input) == 4 else 1
    nInputPlane = input[-3]
    inputHeight = input[-2]
    inputWidth = input[-1]
    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        dilationH,
        dilationW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
    )
    if len(input) == 3:
        return [nInputPlane, outputHeight, outputWidth]
    else:
        return [nbatch, nInputPlane, outputHeight, outputWidth]
def max_pool2d_with_indices(
    input: list[int],
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
):
    """Shapes of ``(output, indices)``: both equal the max_pool2d output shape (same list)."""
    pooled = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
    return pooled, pooled
def upsample_nearest2d(
    input: list[int],
    output_size: Optional[list[int]],
    scale_factors: Optional[list[float]],
):
    """Output shape of 2-D nearest-neighbour upsampling.

    The first two entries of ``input`` are kept. The last two entries come
    from exactly one of:

    * ``output_size`` (two sizes, used as-is), or
    * ``scale_factors`` (two factors); entry ``i`` is
      ``int(input[2 + i] * scale_factors[i])``, truncated toward zero.

    Raises:
        AssertionError: if neither or both of ``output_size`` and
            ``scale_factors`` are given, or the chosen one does not have
            length 2.
    """
    out: list[int] = []
    out.append(input[0])
    out.append(input[1])
    if scale_factors is None and output_size is None:
        raise AssertionError("Either output_size or scale_factors must be presented")
    # Checked once up front. The original repeated this test inside the
    # scale_factors branch, where it could never fire.
    if output_size is not None and scale_factors is not None:
        raise AssertionError(
            "Must specify exactly one of output_size and scale_factors"
        )
    if output_size is not None:
        if len(output_size) != 2:
            raise AssertionError(
                f"Expected output_size to have length 2, but got {len(output_size)}"
            )
        out.append(output_size[0])
        out.append(output_size[1])
    elif scale_factors is not None:
        if len(scale_factors) != 2:
            raise AssertionError(
                f"Expected scale_factors to have length 2, but got {len(scale_factors)}"
            )
        out.append(int(input[2] * scale_factors[0]))
        out.append(int(input[3] * scale_factors[1]))
    return out
def mm(self: list[int], mat2: list[int]):
    """Shape of a matrix product: ``[n, k] x [k, p] -> [n, p]``.

    Raises:
        AssertionError: if either operand is not 2-D or the inner sizes differ.
    """
    if len(self) != 2:
        raise AssertionError(f"self must be a matrix (got {len(self)} dimensions)")
    if len(mat2) != 2:
        raise AssertionError(f"mat2 must be a matrix (got {len(mat2)} dimensions)")
    rows = self[0]
    inner_left = self[1]
    inner_right = mat2[0]
    cols = mat2[1]
    if inner_left != inner_right:
        raise AssertionError(
            f"Matrix dimensions don't match for mm: self[1]={inner_left}, mat2[0]={inner_right}"
        )
    return [rows, cols]
def dot(self: list[int], tensor: list[int]):
    """Shape of the dot product of two equal-length 1-D tensors: a 0-D shape ``[]``.

    Raises:
        AssertionError: if either operand is not 1-D or their lengths differ.
    """
    if len(self) != 1 or len(tensor) != 1:
        raise AssertionError(
            f"Expected 1D tensors for dot, got len(self)={len(self)}, "
            f"len(tensor)={len(tensor)}"
        )
    left = self[0]
    right = tensor[0]
    if left != right:
        raise AssertionError(
            f"Dot product dimension mismatch: self[0]={left}, tensor[0]={right}"
        )
    scalar_shape: list[int] = []
    return scalar_shape
def mv(self: list[int], vec: list[int]):
    """Shape of a matrix-vector product: ``[n, m] x [m] -> [n]``.

    Raises:
        AssertionError: if the operands are not 2-D and 1-D, or ``m`` differs.
    """
    if len(self) != 2 or len(vec) != 1:
        raise AssertionError(
            f"Expected 2D matrix and 1D vector, got len(self)={len(self)}, "
            f"len(vec)={len(vec)}"
        )
    rows = self[0]
    cols = self[1]
    if cols != vec[0]:
        raise AssertionError(
            f"Matrix-vector dimension mismatch: self[1]={cols}, vec[0]={vec[0]}"
        )
    # TODO: return self
    return [rows]
def unsqueeze(li: list[int], dim: int):
    """Shape after inserting a size-1 dim at ``dim``.

    A negative ``dim`` is wrapped against the new rank ``len(li) + 1``.
    """
    position = maybe_wrap_dim(dim, len(li) + 1)
    result = _copy(li)
    result.insert(position, 1)
    return result
def squeeze_nodim(li: list[int]):
    """Shape after removing every size-1 dim."""
    kept: list[int] = []
    for extent in li:
        if extent != 1:
            kept.append(extent)
    return kept
def squeeze(li: list[int], dim: int):
    """Shape after removing ``dim`` if its size is 1 (otherwise unchanged).

    ``dim`` may be negative.
    """
    target = maybe_wrap_dim(dim, len(li))
    kept: list[int] = []
    for i in range(len(li)):
        if i != target or li[i] != 1:
            kept.append(li[i])
    return kept
def squeeze_dims(li: list[int], dims: list[int]):
    """Shape after removing each dim listed in ``dims`` whose size is 1.

    Entries of ``dims`` may be negative. Listed dims whose size is not 1 are
    kept. An empty ``dims`` leaves the shape unchanged. A new list is always
    returned, never ``li`` itself.
    """
    if len(dims) == 0:
        # Copy rather than returning the input, so a caller that mutates the
        # result cannot alias ``li`` (squeeze/squeeze_nodim also return fresh lists).
        return _copy(li)
    wrapped_dims = _copy(dims)
    for i in range(len(dims)):
        wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li))
    result: list[int] = []
    for i in range(len(li)):
        if li[i] != 1 or i not in wrapped_dims:
            result.append(li[i])
    return result
def index_select(self: list[int], dim: int, index: list[int]):
    """Shape of ``self.index_select(dim, index)``.

    ``self`` with dim ``dim`` replaced by the element count of ``index``.
    ``index`` must be 0-D or 1-D, and ``dim`` may be negative.
    """
    dim = maybe_wrap_dim(dim, len(self))
    selected = multiply_integers(index)
    if len(index) > 1:
        raise AssertionError(f"Expected len(index) <= 1, but got {len(index)}")
    if not (dim == 0 or dim < len(self)):
        raise AssertionError(
            f"Expected dim ({dim}) == 0 or dim < len(self) ({len(self)})"
        )
    result: list[int] = []
    for i in range(len(self)):
        result.append(selected if i == dim else self[i])
    return result
def embedding(
    weight: list[int],
    indices: list[int],
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
):
    """Output shape of an embedding lookup: ``indices`` shape plus ``weight[1]``.

    ``weight`` must be 2-D. A 1-D ``indices`` is handled via
    :func:`index_select` on dim 0. The remaining keyword arguments do not
    affect the shape.
    """
    if len(weight) != 2:
        raise AssertionError(f"Expected weight to be 2D, but got {len(weight)}D")
    if len(indices) == 1:
        return index_select(weight, 0, indices)
    result = _copy(indices)
    result.append(weight[1])
    return result
def max_int():
    """Largest signed 64-bit integer (2**63 - 1), used as an "unbounded" slice end."""
    return 0x7FFFFFFFFFFFFFFF
def slice(
    self: list[int], dim: int, start: Optional[int], end: Optional[int], step: int
):
    """Shape of slicing ``self`` along ``dim`` with ``start:end:step``.

    ``start``/``end`` default to 0 and ``max_int()``. Negative bounds are
    shifted once by the size of ``dim``. ``start`` is then clamped to
    ``[0, size]`` and ``end`` to ``[start, size]``. The sliced dim becomes
    ``ceil((end - start) / step)``; all other dims are unchanged.

    Raises:
        AssertionError: for a 0-D ``self``, an out-of-range ``dim``, or a
            non-positive ``step``.
    """
    ndim = len(self)
    if ndim == 0:
        raise AssertionError("Cannot slice a 0-dimensional tensor")
    dim = maybe_wrap_dim(dim, ndim)
    start_val = start if start is not None else 0
    end_val = end if end is not None else max_int()
    if step <= 0:
        raise AssertionError(f"Expected step > 0, but got {step}")
    # A start equal to max_int() is treated as 0.
    if start_val == max_int():
        start_val = 0
    # Negative bounds count back from the end of the dim (shifted once only).
    if start_val < 0:
        start_val += self[dim]
    if end_val < 0:
        end_val += self[dim]
    # Clamp start into [0, size].
    if start_val < 0:
        start_val = 0
    elif start_val > self[dim]:
        start_val = self[dim]
    # Clamp end into [start, size] so the slice length is never negative.
    if end_val < start_val:
        end_val = start_val
    elif end_val >= self[dim]:
        end_val = self[dim]
    slice_len = end_val - start_val
    out = _copy(self)
    # Ceiling division: number of elements taken when stepping by `step`.
    out[dim] = (slice_len + step - 1) // step
    return out
def check_cat_no_zero_dim(tensors: list[list[int]]):
    """Raise if any shape in ``tensors`` is 0-D; returns nothing otherwise."""
    for shape in tensors:
        if len(shape) == 0:
            raise AssertionError("Cannot concatenate tensor with 0 dimensions")
def legacy_cat_wrap_dim(dim: int, tensor_sizes: list[list[int]]):
    """Wrap ``dim`` against the rank of the first shape that is not exactly ``[0]``.

    Shapes equal to ``[0]`` are skipped. If every shape is ``[0]`` (or the list
    is empty), ``dim`` is returned unwrapped.
    """
    out_dim: Optional[int] = None
    for size in tensor_sizes:
        is_legacy_empty = len(size) == 1 and size[0] == 0
        if not is_legacy_empty and out_dim is None:
            out_dim = maybe_wrap_dim(dim, len(size))
    if out_dim is None:
        out_dim = dim
    return out_dim
def should_skip(tensor: list[int]):
    """True iff ``tensor`` is exactly ``[0]``: one dim holding zero elements.

    Such shapes are ignored by :func:`cat`. For a 1-D shape, "zero elements"
    is the same as its only size being 0.
    """
    return len(tensor) == 1 and tensor[0] == 0
def check_cat_shape_except_dim(
    first: list[int], second: list[int], dimension: int, index: int
):
    """Raise unless ``first`` and ``second`` agree everywhere except ``dimension``.

    Both shapes must have the same rank. ``index`` (the position of
    ``second`` in the cat inputs) is accepted but not used.
    """
    rank = len(first)
    other_rank = len(second)
    if rank != other_rank:
        raise AssertionError(
            f"Tensors must have same number of dimensions, got {rank} and "
            f"{other_rank}"
        )
    for d in range(rank):
        if d == dimension:
            continue
        if first[d] != second[d]:
            raise AssertionError(
                f"Sizes of tensors must match except in dimension {dimension}, "
                f"got {first[d]} and {second[d]} at dimension {d}"
            )
def cat(tensors: list[list[int]], dim: int):
    """Shape of ``torch.cat(tensors, dim)``.

    Shapes equal to ``[0]`` are skipped. The last non-skipped shape is the
    reference: every other non-skipped shape must match it except along
    ``dim``, and the result is the reference with ``dim`` set to the summed
    sizes. If every shape is skipped, the result is ``[0]``.
    """
    check_cat_no_zero_dim(tensors)
    dim = legacy_cat_wrap_dim(dim, tensors)
    if len(tensors) <= 0:
        raise AssertionError("Cannot concatenate empty list of tensors")
    reference: Optional[list[int]] = None
    for candidate in tensors:
        if not should_skip(candidate):
            reference = candidate
    if reference is None:
        return [0]
    total = 0
    for position, candidate in enumerate(tensors):
        if should_skip(candidate):
            continue
        check_cat_shape_except_dim(reference, candidate, dim, position)
        total = total + candidate[dim]
    result = _copy(reference)
    result[dim] = total
    return result
def stack(tensors: list[list[int]], dim: int):
    """Shape of ``torch.stack(tensors, dim)``: unsqueeze each shape at ``dim``, then cat along it."""
    expanded = [unsqueeze(shape, dim) for shape in tensors]
    return cat(expanded, dim)
def select(self: list[int], dim: int, index: int):
    """Shape of ``self.select(dim, index)``: ``self`` with ``dim`` removed.

    ``dim`` may be negative. ``index`` must lie in ``[-size, size)`` for that
    dim; its value does not otherwise affect the result shape.

    Raises:
        AssertionError: for a 0-D ``self``, an out-of-range ``dim``, or an
            out-of-range ``index``.
    """
    ndim = len(self)
    if ndim == 0:
        raise AssertionError("Cannot select from a 0-dimensional tensor")
    dim = maybe_wrap_dim(dim, ndim)
    size = self[dim]
    if index < -size or index >= size:
        raise AssertionError(
            f"Index {index} is out of bounds for dimension {dim} with size {size}"
        )
    # The wrapped index is not needed for the shape, so it is not computed.
    out: list[int] = []
    for i in range(ndim):
        if i != dim:
            out.append(self[i])
    return out
def matmul(tensor1: list[int], tensor2: list[int]):
    """Shape of ``torch.matmul(tensor1, tensor2)``, dispatched on operand ranks.

    * 1-D x 1-D: dot product, ``[]``;
    * 2-D x 1-D: matrix-vector, ``[n]``;
    * 1-D x 2-D: vector treated as ``[1, m]``, multiplied, then that dim removed;
    * 2-D x 2-D: matrix-matrix;
    * otherwise (both at least 1-D, at least one at least 3-D): the leading
      (batch) dims are broadcast, then ``n`` / ``p`` are appended for operands
      that are at least 2-D.

    Raises:
        AssertionError: if either operand is 0-D, the contracted sizes differ,
            or the batch dims do not broadcast.
    """
    dim_tensor1 = len(tensor1)
    dim_tensor2 = len(tensor2)
    if dim_tensor1 == 1 and dim_tensor2 == 1:
        return dot(tensor1, tensor2)
    elif dim_tensor1 == 2 and dim_tensor2 == 1:
        return mv(tensor1, tensor2)
    elif dim_tensor1 == 1 and dim_tensor2 == 2:
        return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0)
    elif dim_tensor1 == 2 and dim_tensor2 == 2:
        return mm(tensor1, tensor2)
    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
        # We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 can be a list);
        # m1 and m2 are tracked separately so the error can report both.
        m1 = tensor1[-1]
        m2 = tensor2[-2] if dim_tensor2 > 1 else tensor2[-1]
        if m1 != m2:
            raise AssertionError(
                f"Matrix dimensions don't match for matmul: tensor1[-1]={m1}, "
                f"contracted dim of tensor2={m2}"
            )
        n = tensor1[-2] if dim_tensor1 > 1 else 1
        batch_tensor1: list[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor1 - 2):
            batch_tensor1.append(tensor1[i])
        p = tensor2[-1]
        batch_tensor2: list[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor2 - 2):
            batch_tensor2.append(tensor2[i])
        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
        expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)
        # broadcast() builds a fresh list, so appending to it does not alias the inputs.
        output_shape = expand_batch_portion
        if dim_tensor1 > 1:
            output_shape.append(n)
        if dim_tensor2 > 1:
            output_shape.append(p)
        return output_shape
    else:
        raise AssertionError("both arguments to matmul need to be at least 1D")
def t(self: list[int]):
    """Shape of ``self.t()``: 0-D and 1-D shapes unchanged, 2-D shapes swapped.

    Raises:
        AssertionError: if ``self`` has more than 2 dims.
    """
    rank = len(self)
    if rank > 2:
        raise AssertionError(
            f"Expected tensor to have <= 2 dimensions, but got {len(self)}"
        )
    if rank == 2:
        return [self[1], self[0]]
    if rank == 1:
        return [self[0]]
    empty: list[int] = []
    return empty
def transpose(self: list[int], dim0: int, dim1: int):
    """Shape after swapping dims ``dim0`` and ``dim1`` (either may be negative)."""
    rank = len(self)
    first = maybe_wrap_dim(dim0, rank)
    second = maybe_wrap_dim(dim1, rank)
    result = _copy(self)
    if first != second:
        result[first] = self[second]
        result[second] = self[first]
    return result
def linear(input: list[int], weight: list[int], bias: Optional[list[int]]):
    """Shape of a linear layer: ``matmul(input, weight.t())``.

    A given ``bias`` must broadcast to that shape without changing it.
    """
    out = matmul(input, t(weight))
    if bias is None:
        return out
    if broadcast(bias, out) != out:
        raise AssertionError(
            f"Bias shape {bias} is not broadcastable to output shape {out}"
        )
    return out
def addmm(self: list[int], mat1: list[int], mat2: list[int], beta: Any, alpha: Any):
    """Shape of ``addmm``: ``self`` broadcast with ``mm(mat1, mat2)``; ``beta``/``alpha`` are ignored."""
    product = mm(mat1, mat2)
    return broadcast(self, product)
def check_non_negative(array: list[int]) -> bool:
    """Return True if ANY element of ``array`` is negative.

    Despite the name, a True result means the check failed: callers raise
    when it returns True (e.g. for negative padding).
    """
    # TODO: look into rewriting with early return and getting loop unrolling to fire
    has_negative = False
    for value in array:
        if value < 0:
            has_negative = True
    return has_negative
def check_shape_forward(
    input: list[int],
    weight_sizes: list[int],
    bias: Optional[list[int]],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    groups: int,
):
    """Validate the arguments of a non-transposed convolution; returns nothing.

    Raises AssertionError unless:

    * every stride is positive and every padding is non-negative;
    * ``weight_sizes`` has the same rank as ``input``;
    * ``weight_sizes[0]`` is at least ``groups`` and divisible by it;
    * ``input[1] == weight_sizes[1] * groups``;
    * ``bias`` is None or equal to ``[weight_sizes[0]]``;
    * each padded spatial input is at least the dilated kernel extent.
    """
    k = len(input)
    weight_dim = len(weight_sizes)
    # Stride is used as a divisor when computing output sizes (e.g. in
    # conv_output_size), so zero must be rejected as well as negatives.
    for s in stride:
        if s <= 0:
            raise AssertionError(f"Stride must be positive, got {stride}")
    # TODO: assertions could be expanded with the error messages
    # check_non_negative returns True when some entry IS negative.
    if check_non_negative(padding):
        raise AssertionError(f"Padding must be non-negative, got {padding}")
    if weight_dim != k:
        raise AssertionError(f"Expected weight_dim ({weight_dim}) == k ({k})")
    if weight_sizes[0] < groups:
        raise AssertionError(
            f"Expected weight_sizes[0] ({weight_sizes[0]}) >= groups ({groups})"
        )
    if (weight_sizes[0] % groups) != 0:
        raise AssertionError(
            f"Expected weight_sizes[0] ({weight_sizes[0]}) to be divisible by "
            f"groups ({groups})"
        )
    # only handling not transposed
    if input[1] != weight_sizes[1] * groups:
        raise AssertionError(
            f"Expected input[1] ({input[1]}) == weight_sizes[1] * groups "
            f"({weight_sizes[1] * groups})"
        )
    if bias is not None and not (len(bias) == 1 and bias[0] == weight_sizes[0]):
        raise AssertionError(
            f"Expected bias to be None or have shape [1] with value "
            f"weight_sizes[0]={weight_sizes[0]}, got {bias}"
        )
    for i in range(2, k):
        if (input[i] + 2 * padding[i - 2]) < (
            dilation[i - 2] * (weight_sizes[i] - 1) + 1
        ):
            raise AssertionError(
                f"Calculated padded input size ({input[i] + 2 * padding[i - 2]}) "
                f"is smaller than effective kernel size "
                f"({dilation[i - 2] * (weight_sizes[i] - 1) + 1}) at dimension {i}"
            )
# Non-transposed convolution only; transposed convolution is not handled here yet.
def conv_output_size(
    input_size: list[int],
    weight_size: list[int],
    bias: Optional[list[int]],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    groups: int,
):
    """Output shape of an N-d convolution after validating its arguments.

    The result is ``[input_size[0], weight_size[0], *spatial]``. Each spatial
    size is ``(in + 2 * pad - (dil * (k - 1) + 1)) // stride + 1``, with
    ``dil`` taken as 1 when ``dilation`` is empty.
    """
    check_shape_forward(
        input_size, weight_size, bias, stride, padding, dilation, groups
    )
    use_dilation = len(dilation) > 0
    result: list[int] = [input_size[0], weight_size[0]]
    for d in range(2, len(input_size)):
        j = d - 2
        dil = dilation[j] if use_dilation else 1
        effective_kernel = dil * (weight_size[d] - 1) + 1
        result.append((input_size[d] + 2 * padding[j] - effective_kernel) // stride[j] + 1)
    return result
def conv1d(
    input: list[int],
    weight: list[int],
    bias: Optional[list[int]],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    groups: int,
):
    """Output shape of a 1-D convolution; ``input`` and ``weight`` must both be 3-D."""
    if len(weight) != 3:
        raise AssertionError(f"Expected 3D weight for conv1d, got {len(weight)}D")
    if len(input) != 3:
        raise AssertionError(f"Expected 3D input for conv1d, got {len(input)}D")
    return conv_output_size(
        input_size=input,
        weight_size=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
def conv2d(
    input: list[int],
    weight: list[int],
    bias: Optional[list[int]],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    groups: int,
):
    """Output shape of a 2-D convolution; ``input`` and ``weight`` must both be 4-D."""
    if len(weight) != 4:
        raise AssertionError(f"Expected 4D weight for conv2d, got {len(weight)}D")
    if len(input) != 4:
        raise AssertionError(f"Expected 4D input for conv2d, got {len(input)}D")
    return conv_output_size(
        input_size=input,
        weight_size=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
def conv_backwards(
    grad_output: list[int],
    input: list[int],
    weight: list[int],
    biases: Optional[list[int]],
):
    """Shapes of ``(grad_input, grad_weight, grad_bias)`` for convolution backward.

    The bias gradient shape is always produced, regardless of whether
    ``biases`` is supplied. Its single size is ``grad_output[1]``.
    """
    grad_input = _copy(input)
    grad_weight = _copy(weight)
    grad_bias = [grad_output[1]]
    return grad_input, grad_weight, grad_bias
def conv_transpose2d_input(
    input: list[int],
    weight: list[int],
    bias: Optional[list[int]] = None,
    stride: Optional[list[int]] = None,
    padding: Optional[list[int]] = None,
    output_padding: Optional[list[int]] = None,
    groups: int = 1,
    dilation: Optional[list[int]] = None,
) -> list[int]:
    """Output shape of a transposed convolution.

    Defaults: stride ``[1, 1]``, padding ``[0, 0]``, output_padding ``[0, 0]``,
    dilation ``[1, 1]`` (an explicitly empty dilation means 1). The result is
    ``[input[0], weight[1] * groups, *spatial]``, where each spatial size is
    ``(in - 1) * stride - 2 * pad + dil * (k - 1) + output_pad + 1``.
    ``bias`` does not affect the shape.
    """
    if stride is None:
        stride = [1, 1]
    if padding is None:
        padding = [0, 0]
    if output_padding is None:
        output_padding = [0, 0]
    if dilation is None:
        dilation = [1, 1]
    use_dilation = len(dilation) > 0
    result: list[int] = [input[0], weight[1] * groups]
    for d in range(2, len(input)):
        j = d - 2
        dil = dilation[j] if use_dilation else 1
        span = dil * (weight[d] - 1)
        result.append(
            (input[d] - 1) * stride[j] - 2 * padding[j] + span + output_padding[j] + 1
        )
    return result
def conv_forwards(
    input: list[int],
    weight: list[int],
    bias: Optional[list[int]],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    transposed: bool,
    output_padding: list[int],
    groups: int,
) -> list[int]:
    """Output shape of a (possibly transposed) convolution, without validation.

    The first output size is ``input[0]``. The second is ``weight[1] * groups``
    when transposed, else ``weight[0]``. Spatial sizes:

    * transposed: ``(in - 1) * stride - 2 * pad + dil * (k - 1) + out_pad + 1``
    * otherwise: ``(in + 2 * pad - (dil * (k - 1) + 1)) // stride + 1``

    An empty ``dilation`` means 1 and an empty ``output_padding`` means 0.
    ``bias`` does not affect the shape.
    """
    use_dilation = len(dilation) > 0
    use_output_padding = len(output_padding) > 0
    batch = input[0]
    out_channels = weight[1] * groups if transposed else weight[0]
    result: list[int] = [batch, out_channels]
    for d in range(2, len(input)):
        j = d - 2
        dil = dilation[j] if use_dilation else 1
        extra = output_padding[j] if use_output_padding else 0
        if transposed:
            span = dil * (weight[d] - 1)
            result.append((input[d] - 1) * stride[j] - 2 * padding[j] + span + extra + 1)
        else:
            effective_kernel = dil * (weight[d] - 1) + 1
            result.append((input[d] + 2 * padding[j] - effective_kernel) // stride[j] + 1)
    return result
def _conv_forwards(
    input: list[int],
    weight: list[int],
    bias: Optional[list[int]],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    transposed: bool,
    output_padding: list[int],
    groups: int,
    benchmark: bool,
    deterministic: bool,
    cudnn_enabled: bool,
    allow_tf32: bool,
) -> list[int]:
    """Same as :func:`conv_forwards`; the backend flags do not affect the shape."""
    return conv_forwards(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        transposed=transposed,
        output_padding=output_padding,
        groups=groups,
    )
def batch_norm(
    input: list[int],
    weight: Optional[list[int]],
    bias: Optional[list[int]],
    running_mean: Optional[list[int]],
    running_var: Optional[list[int]],
    training: bool,
    momentum: float,
    eps: float,
    cudnn_enabled: bool,
):
    """Shape of batch norm: a new list equal to ``input``; other arguments are ignored."""
    return [size for size in input]
def conv3d(
    input: list[int],
    weight: list[int],
    bias: Optional[list[int]],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    groups: int,
):
    """Output shape of a 3-D convolution; ``input`` and ``weight`` must both be 5-D."""
    if len(weight) != 5:
        raise AssertionError(f"Expected 5D weight for conv3d, got {len(weight)}D")
    if len(input) != 5:
        raise AssertionError(f"Expected 5D input for conv3d, got {len(input)}D")
    return conv_output_size(
        input_size=input,
        weight_size=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    """Normalize a possibly negative dimension index to a non-negative one.

    Args:
        dim: dimension index; negative values count back from the end.
        dim_post_expr: number of dimensions being indexed. A value ``<= 0``
            is treated as 1 when ``wrap_scalar`` is True.
        wrap_scalar: whether ``dim_post_expr <= 0`` is allowed at all.

    Returns:
        ``dim`` mapped into ``[0, dim_post_expr - 1]``.

    Raises:
        AssertionError: if ``dim_post_expr <= 0`` while ``wrap_scalar`` is
            False, or if ``dim`` lies outside
            ``[-dim_post_expr, dim_post_expr - 1]``.
    """
    if dim_post_expr <= 0:
        if not wrap_scalar:
            raise AssertionError(
                "Expected wrap_scalar to be True when dim_post_expr <= 0"
            )
        # Index as if there were one dimension, so only 0 and -1 are valid.
        dim_post_expr = 1
    # lo/hi rather than min/max so the builtins are not shadowed.
    lo = -dim_post_expr
    hi = dim_post_expr - 1
    if dim < lo or dim > hi:
        raise AssertionError(
            f"Dimension {dim} out of range (expected to be in range [{lo}, {hi}])"
        )
    if dim < 0:
        dim += dim_post_expr
    return dim
def zero_dim_tensor(input: Any):
    """Shape of a 0-dim tensor: an empty size list.

    ``input`` is ignored. A fresh empty list is returned on every call; the
    annotation is needed so TorchScript types the empty literal as int list.
    """
    scalar_sizes: list[int] = []
    return scalar_sizes
def multiply_integers(li: list[int]):
    """Return the product of the integers in ``li`` (1 for an empty list)."""
    product = 1
    for factor in li:
        product *= factor
    return product
def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
    """Shape function for ``aten::arange(end)``: ``[ceil(end)]``.

    A negative ``end`` is rejected. ``inp0``..``inp3`` correspond to the
    schema's keyword-only arguments and are unused.
    """
    if end < 0:
        raise AssertionError(f"Expected end ({end}) >= 0")
    length = int(math.ceil(end))
    return [length]
def arange_start(
    start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
):
    """Shape function for ``aten::arange.start``: ``[ceil(end - start)]``.

    The step is implicitly 1, so ``end`` must not be smaller than ``start``.
    Negative bounds are valid (``arange(-5, -2)`` has three elements); only
    their difference determines the length, so no sign check on ``end`` is
    made. ``inp0``..``inp3`` correspond to the schema's keyword-only
    arguments and are unused.

    Raises:
        AssertionError: if ``end < start``.
    """
    if end < start:
        raise AssertionError(f"Expected end ({end}) >= start ({start})")
    return [int(math.ceil(end - start))]
def arange_start_step(
    start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
):
    """Shape function for ``aten::arange.start_step``.

    Returns ``[ceil((end - start) / step)]``. Raises AssertionError when
    ``step`` is zero, or when it points away from ``end``: a negative step
    needs ``start >= end``, any other step needs ``end >= start``.
    ``inp0``..``inp3`` are unused.
    """
    if step == 0:
        raise AssertionError("step must not be zero")
    descending = step < 0
    if descending and start < end:
        raise AssertionError(
            f"Expected start ({start}) >= end ({end}) when step < 0"
        )
    if not descending and end < start:
        raise AssertionError(
            f"Expected end ({end}) >= start ({start}) when step > 0"
        )
    return [int(math.ceil((end - start) / step))]
def permute(input: list[int], dims: list[int]):
    """Shape function for ``aten::permute``: reorder ``input`` by ``dims``.

    ``dims`` must have one entry per input dimension. Entries are wrapped
    with ``maybe_wrap_dim`` (which raises if out of range) before duplicates
    are checked, and each dimension may appear only once.
    """
    ndim = len(dims)
    if len(input) != ndim:
        raise AssertionError(
            f"Expected len(input) ({len(input)}) == len(dims) ({ndim})"
        )
    wrapped = [maybe_wrap_dim(d, ndim) for d in dims]
    # Report the first entry whose dimension already appeared earlier.
    seen: list[int] = []
    for axis in wrapped:
        if axis in seen:
            raise AssertionError(
                f"Repeated dimension {axis} in permute dimensions"
            )
        seen.append(axis)
    return [input[axis] for axis in wrapped]
def movedim(self: list[int], source: list[int], destination: list[int]) -> list[int]:
    """Shape function for ``aten::movedim.intlist``.

    Places input dimension ``source[i]`` at output position
    ``destination[i]``. The dimensions that are not moved keep their
    relative order and fill the untargeted output positions. A full
    permutation ``order`` is built and handed to ``permute``.

    Inputs with at most one dimension are returned as-is (the same list
    object, not a copy).

    NOTE(review): ``len(source)`` and ``len(destination)`` are never
    compared. A shorter ``destination`` fails with an IndexError, and extra
    ``destination`` entries are silently ignored. Confirm whether an
    explicit check is wanted.
    """
    self_dim = len(self)
    if self_dim <= 1:
        return self
    # Wrap negative indices in both lists.
    normalized_src: list[int] = []
    normalized_dst: list[int] = []
    for i in range(len(source)):
        normalized_src.append(maybe_wrap_dim(source[i], self_dim))
        normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
    # order[k] = input dimension placed at output position k (-1 = unset).
    # src_dims/dst_dims start as the identity; moved entries become -1.
    order = [-1 for i in range(self_dim)]
    src_dims = [i for i in range(self_dim)]
    dst_dims = [i for i in range(self_dim)]
    for i in range(len(source)):
        order[normalized_dst[i]] = normalized_src[i]
        src_dims[normalized_src[i]] = -1
        dst_dims[normalized_dst[i]] = -1
    # Collect input dimensions that were not moved and output positions that
    # were not targeted, each in ascending order.
    source_dims: list[int] = []
    destination_dims: list[int] = []
    for ele in src_dims:
        if ele != -1:
            source_dims.append(ele)
    for ele in dst_dims:
        if ele != -1:
            destination_dims.append(ele)
    # Pair them up in order so untouched dimensions keep their relative order.
    rest_dim = self_dim - len(source)
    for i in range(rest_dim):
        order[destination_dims[i]] = source_dims[i]
    return permute(self, order)
def flatten(input: list[int], start_dim: int, end_dim: int):
    """Shape function for ``aten::flatten.using_ints``.

    Dimensions ``start_dim`` through ``end_dim`` (inclusive, negatives
    wrapped) are merged into a single dimension whose size is their product.
    A 0-dim input yields ``[1]``. When the wrapped bounds coincide, a copy of
    ``input`` is returned.

    Raises:
        AssertionError: if the wrapped ``start_dim`` exceeds the wrapped
            ``end_dim`` (or either is out of range, via ``maybe_wrap_dim``).
    """
    rank = len(input)
    first = maybe_wrap_dim(start_dim, rank)
    last = maybe_wrap_dim(end_dim, rank)
    if first > last:
        raise AssertionError(f"Expected start_dim ({first}) <= end_dim ({last})")
    if rank == 0:
        return [1]
    if first == last:
        # TODO: return self
        return [size for size in input]
    # Explicit loops instead of list slicing.
    # TODO: use slicing when slice optimization has landed
    merged = 1
    for k in range(first, last + 1):
        merged *= input[k]
    result = [input[i] for i in range(first)]
    result.append(merged)
    for j in range(last + 1, rank):
        result.append(input[j])
    return result
def nonzero_lower_bound(input: list[int]):
    """Lower-bound shape for ``aten::nonzero``: ``[0, len(input)]``."""
    ndim = len(input)
    return [0, ndim]
def nonzero_upper_bound(input: list[int]):
    """Upper-bound shape for ``aten::nonzero``: ``[numel(input), len(input)]``."""
    ndim = len(input)
    return [numel(input), ndim]
def _reduce_along_dim(self: list[int], dim: int, keepdim: bool):
    """Shape after reducing over ``dim``.

    The reduced dimension (negative ``dim`` wrapped) is dropped, or replaced
    by 1 when ``keepdim`` is True. All other sizes are kept in order.
    """
    axis = maybe_wrap_dim(dim, len(self))
    kept: list[int] = []
    for idx, size in enumerate(self):
        if idx != axis:
            kept.append(size)
        elif keepdim:
            kept.append(1)
    return kept
def argmax(
    self: list[int], dim: Optional[int] = None, keepdim: bool = False
) -> list[int]:
    """Shape function for ``aten::argmax``.

    With ``dim`` given, the result is ``self`` reduced along ``dim``
    (``keepdim`` controls whether a size-1 dimension remains). With
    ``dim=None`` the result is 0-dim (``[]``) and ``keepdim`` is not
    consulted.
    """
    if dim is not None:
        return _reduce_along_dim(self, dim, keepdim)
    return []
def bmm(self: list[int], mat2: list[int]) -> list[int]:
    """Shape function for ``aten::bmm``: ``[b, n, m] @ [b, m, p] -> [b, n, p]``.

    Raises AssertionError when either operand is not 3-D, when the leading
    sizes differ, or when ``self[2] != mat2[1]``. Checks run in that order.
    """
    self_rank = len(self)
    if self_rank != 3:
        raise AssertionError(f"bmm only supports 3D tensors, got {self_rank}D")
    other_rank = len(mat2)
    if other_rank != 3:
        raise AssertionError(f"bmm only supports 3D tensors, got {other_rank}D")
    batch = self[0]
    if batch != mat2[0]:
        raise AssertionError(
            f"mismatching batch dimension: self[0]={batch}, mat2[0]={mat2[0]}"
        )
    inner = self[2]
    if inner != mat2[1]:
        raise AssertionError(
            f"mismatching contracting dimension: self[2]={inner}, mat2[1]={mat2[1]}"
        )
    return [batch, self[1], mat2[2]]
- def _shape_as_tensor(self: list[int]) -> list[int]:
- return [len(self)]
def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]:
    """Shape function for ``aten::topk``.

    Both outputs get the same shape, and the same list object is returned
    twice. For a 0-dim input that shape is ``[]``. Otherwise it is a copy of
    ``self`` with ``self[dim]`` replaced by ``k``. ``dim`` is used directly
    as a list index, so negative values count from the end.

    Raises:
        AssertionError: if ``k`` exceeds ``self[dim]``.
    """
    if len(self) == 0:
        no_dims: list[int] = []
        return no_dims, no_dims
    extent = self[dim]
    if k > extent:
        raise AssertionError(
            f"k ({k}) is too big for dimension {dim} of size {extent}"
        )
    values_shape = _copy(self)
    values_shape[dim] = k
    return values_shape, values_shape
def nll_loss_forward(
    self: list[int], target: list[int], weight: Optional[list[int]], reduction: int
) -> tuple[list[int], list[int]]:
    """Shape function for ``aten::nll_loss_forward``.

    Mirrors the meta function in LossNLL.cpp. ``self`` must be 1-D or 2-D
    and ``target`` at most 1-D. Unless ``self`` is 1-D with a 0-dim target,
    ``self[0]`` must equal ``target[0]``. ``weight``, if given, must be
    ``[self[-1]]``.

    Returns ``(output_shape, total_weight_shape)``. The output is ``[self[0]]``
    only for ``reduction == 0`` with 2-D ``self``; otherwise it is the same
    empty list object used for the second element.
    """
    self_dim = len(self)
    target_dim = len(target)
    if self_dim <= 0 or self_dim > 2:
        raise AssertionError(f"Expected 0 < self_dim <= 2, but got self_dim={self_dim}")
    if target_dim > 1:
        raise AssertionError(f"Expected target_dim <= 1, but got {target_dim}")
    has_batch_dim = self_dim != 1 or target_dim != 0
    if has_batch_dim and self[0] != target[0]:
        raise AssertionError(
            f"Batch size mismatch: self[0]={self[0]}, target[0]={target[0]}"
        )
    n_classes = self[-1]
    if weight is not None and (len(weight) != 1 or weight[0] != n_classes):
        raise AssertionError(
            f"Expected weight to be None or have shape [n_classes], "
            f"got {weight} with n_classes={n_classes}"
        )
    scalar_shape: list[int] = []
    reduction_shape = [self[0]] if reduction == 0 and self_dim == 2 else scalar_shape
    return reduction_shape, scalar_shape
def native_layer_norm(
    input: list[int], normalized_shape: list[int]
) -> tuple[list[int], list[int], list[int]]:
    """Shape function for ``aten::native_layer_norm``.

    Returns a copy of ``input`` followed by the statistics shape twice (the
    same list object). That shape keeps the leading
    ``len(input) - len(normalized_shape)`` sizes of ``input`` and sets every
    trailing (normalized) dimension to 1.

    Raises:
        AssertionError: if ``normalized_shape`` has more entries than ``input``.
    """
    kept_dims = len(input) - len(normalized_shape)
    if kept_dims < 0:
        raise AssertionError(
            f"Expected len(input) ({len(input)}) >= len(normalized_shape) "
            f"({len(normalized_shape)})"
        )
    stats_shape = [input[i] if i < kept_dims else 1 for i in range(len(input))]
    return _copy(input), stats_shape, stats_shape
def native_batch_norm(
    input: list[int],
    weight: Optional[list[int]],
    bias: Optional[list[int]],
    running_mean: Optional[list[int]],
    running_var: Optional[list[int]],
    training: bool,
) -> tuple[list[int], list[int], list[int]]:
    """Shape function for ``aten::native_batch_norm`` and its ``_legit`` variants.

    The first output is a copy of ``input``. The second and third outputs
    share one list: ``[input[1]]`` (presumably the channel count) when
    ``training``, else ``[0]``.
    """
    stats_size = [input[1]] if training else [0]
    return _copy(input), stats_size, stats_size
def _batch_norm_with_update(
    input: list[int],
    weight: Optional[list[int]],
    bias: Optional[list[int]],
    running_mean: Optional[list[int]],
    running_var: Optional[list[int]],
) -> tuple[list[int], list[int], list[int], list[int]]:
    """Shape function for ``_batch_norm_with_update``.

    Returns a copy of ``input``, then ``[input[1]]`` twice (the same list
    object), then ``[0]``.
    """
    dim1_stats = [input[1]]
    output_shape = _copy(input)
    return output_shape, dim1_stats, dim1_stats, [0]
def cross_entropy_loss(
    self: list[int],
    target: list[int],
    weight: Optional[list[int]] = None,
    reduction: int = 1,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
) -> list[int]:
    """Shape function for ``aten::cross_entropy_loss``.

    Reuses ``nll_loss_forward`` and returns its first (output) shape.
    ``ignore_index`` and ``label_smoothing`` do not affect the shape and are
    unused.
    """
    output_shape, _total_weight_shape = nll_loss_forward(
        self, target, weight, reduction
    )
    return output_shape
- """
- Enabling this is currently deferred, as part of the proposal to suspend
- adding ops.
- There are currently cases where this is being called
- in the SSA opinfo tests with unexpected values (e.g. a list of two ints; see the first
- opinfo test). The behavior of index depends heavily on its inputs.
- This could be an error in how we match up shape functions, or this
- function may simply need to implement every case.
- def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
- assert len(indices) <= len(self), "More indices than dimensions to index"
- broadcasted_shape: List[int] = []
- for index_tensor_shape in indices:
- if index_tensor_shape is not None:
- broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape)
- return broadcasted_shape
- """
ScriptFn = torch._C.ScriptFunction

# Operator schema string -> scripted shape function.
shape_compute_graph_mapping: dict[str, ScriptFn] = {}
# Operator schema string -> (lower-bound, upper-bound) scripted shape functions.
bounded_compute_graph_mapping: dict[str, tuple[ScriptFn, ScriptFn]] = {}
# Memo so each Python shape function is scripted only once.
script_func_map: dict[Callable, ScriptFn] = {}


def process_func(func: Callable):
    """Return the scripted version of ``func``, scripting it on first use.

    On first use the function is compiled with ``torch.jit.script``. Its graph
    is inlined, then peephole optimization and constant propagation are each
    applied twice, in place. The result is memoized in ``script_func_map``.
    """
    cached = script_func_map.get(func)
    if cached is not None:
        return cached
    scripted = torch.jit.script(func)
    graph = scripted.graph
    torch._C._jit_pass_inline(graph)
    for _round in range(2):
        torch._C._jit_pass_peephole(graph)
        torch._C._jit_pass_constant_propagation(graph)
    script_func_map[func] = scripted
    return scripted


def add_shape_compute_mapping(operator_schema: str, func: Callable):
    """Register the scripted ``func`` as the shape function for ``operator_schema``."""
    # Only the dict's contents change, so no ``global`` declaration is needed.
    shape_compute_graph_mapping[operator_schema] = process_func(func)


def add_bounded_compute_mapping(
    operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable
):
    """Register scripted lower- and upper-bound shape functions for ``operator_schema``."""
    lower = process_func(lower_bound_func)
    upper = process_func(upper_bound_func)
    bounded_compute_graph_mapping[operator_schema] = (lower, upper)
# Registrations: elementwise/conversion ops, scalar-to-tensor, and arange.
add_shape_compute_mapping(
    operator_schema="aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::dropout(Tensor input, float p, bool train) -> Tensor",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor",
    func=adaptive_avg_pool2d,
)
add_shape_compute_mapping(
    operator_schema="prim::NumToTensor.Scalar(Scalar a) -> Tensor",
    func=zero_dim_tensor,
)
add_shape_compute_mapping(
    operator_schema="prim::NumToTensor.bool(bool a) -> Tensor",
    func=zero_dim_tensor,
)
add_shape_compute_mapping(
    operator_schema="aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    func=arange_end,
)
add_shape_compute_mapping(
    operator_schema="aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    func=arange_start,
)
add_shape_compute_mapping(
    operator_schema="aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    func=arange_start_step,
)
# Registrations: squeeze/unsqueeze, slicing and indexing, normalization,
# embeddings, the matmul family, pooling and transposes.
add_shape_compute_mapping(
    operator_schema="aten::squeeze(Tensor(a) self) -> Tensor(a)",
    func=squeeze_nodim,
)
add_shape_compute_mapping(
    operator_schema="aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)",
    func=squeeze,
)
add_shape_compute_mapping(
    operator_schema="aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)",
    func=squeeze_dims,
)
add_shape_compute_mapping(
    operator_schema="aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)",
    func=unsqueeze,
)
add_shape_compute_mapping(
    operator_schema="aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
    func=slice,
)
add_shape_compute_mapping(
    operator_schema="aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)",
    func=select,
)
add_shape_compute_mapping(
    operator_schema="aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
    func=index_select,
)
add_shape_compute_mapping(
    operator_schema=(
        "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
        "float eps=1e-05, bool cudnn_enable=True) -> Tensor"
    ),
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor",
    func=embedding,
)
add_shape_compute_mapping(
    operator_schema="aten::mm(Tensor self, Tensor mat2) -> Tensor",
    func=mm,
)
add_shape_compute_mapping(
    operator_schema="aten::dot(Tensor self, Tensor tensor) -> Tensor",
    func=dot,
)
add_shape_compute_mapping(
    operator_schema="aten::mv(Tensor self, Tensor vec) -> Tensor",
    func=mv,
)
add_shape_compute_mapping(
    operator_schema="aten::matmul(Tensor self, Tensor other) -> Tensor",
    func=matmul,
)
add_shape_compute_mapping(
    operator_schema="aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor",
    func=linear,
)
add_shape_compute_mapping(
    operator_schema="aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
    func=max_pool2d,
)
add_shape_compute_mapping(
    operator_schema="aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)",
    func=max_pool2d_with_indices,
)
add_shape_compute_mapping(
    operator_schema="aten::t(Tensor(a) self) -> Tensor(a)",
    func=t,
)
add_shape_compute_mapping(
    operator_schema="aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)",
    func=transpose,
)
# Registrations: convolutions (forward, backward, transposed), batch_norm,
# and flatten.
add_shape_compute_mapping(
    operator_schema="aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor",
    func=conv1d,
)
add_shape_compute_mapping(
    operator_schema="aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
    func=conv2d,
)
add_shape_compute_mapping(
    operator_schema="aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
    func=batch_norm,
)
add_shape_compute_mapping(
    operator_schema="aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor",
    func=conv3d,
)
add_shape_compute_mapping(
    operator_schema="aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)",
    func=conv_backwards,
)
add_shape_compute_mapping(
    operator_schema="aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
    func=conv_forwards,
)
add_shape_compute_mapping(
    operator_schema="aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
    func=_conv_forwards,
)
add_shape_compute_mapping(
    operator_schema="aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor",
    func=conv_transpose2d_input,
)
add_shape_compute_mapping(
    operator_schema="aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)",
    func=flatten,
)
# Registrations: concatenation, permutations, views and expands, reductions,
# addmm and upsampling.
add_shape_compute_mapping(
    operator_schema="aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
    func=cat,
)
add_shape_compute_mapping(
    operator_schema="aten::stack(Tensor[] tensors, int dim=0) -> Tensor",
    func=stack,
)
add_shape_compute_mapping(
    operator_schema="aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)",
    func=permute,
)
add_shape_compute_mapping(
    operator_schema="aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)",
    func=movedim,
)
add_shape_compute_mapping(
    operator_schema="aten::view(Tensor(a) self, int[] size) -> Tensor(a)",
    func=view,
)
add_shape_compute_mapping(
    operator_schema="aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)",
    func=expand,
)
add_shape_compute_mapping(
    operator_schema="aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)",
    func=expand_one_unused,
)
add_shape_compute_mapping(
    operator_schema="aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    func=sum_mean_dim,
)
add_shape_compute_mapping(
    operator_schema="aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    func=sum_mean_dim,
)
add_shape_compute_mapping(
    operator_schema="aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
    func=max_dim,
)
add_shape_compute_mapping(
    operator_schema="aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor",
    func=zero_dim_tensor,
)
add_shape_compute_mapping(
    operator_schema="aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
    func=zero_dim_tensor,
)
add_shape_compute_mapping(
    operator_schema="aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
    func=addmm,
)
add_shape_compute_mapping(
    operator_schema="aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)",
    func=upsample_nearest2d,
)
# Registrations: quantization, argmax/bmm/topk, NLL loss and the native
# layer/batch norm kernels.
add_shape_compute_mapping(
    operator_schema="aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="aten::dequantize(Tensor self) -> Tensor",
    func=unary,
)
add_shape_compute_mapping(
    operator_schema="quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
    func=broadcast,
)
add_shape_compute_mapping(
    operator_schema="aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor",
    func=argmax,
)
add_shape_compute_mapping(
    operator_schema="aten::bmm(Tensor self, Tensor mat2) -> Tensor",
    func=bmm,
)
add_shape_compute_mapping(
    operator_schema="aten::_shape_as_tensor(Tensor self) -> Tensor",
    func=_shape_as_tensor,
)
add_shape_compute_mapping(
    operator_schema="aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
    func=topk,
)
add_shape_compute_mapping(
    operator_schema="aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)",
    func=nll_loss_forward,
)
add_shape_compute_mapping(
    operator_schema="aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
    func=native_layer_norm,
)
add_shape_compute_mapping(
    operator_schema="aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    func=native_batch_norm,
)
add_shape_compute_mapping(
    operator_schema="aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    func=native_batch_norm,
)
# NOTE(review): the `.no_stats` overload name suggests it takes no
# running_mean/running_var, yet this schema string lists them. Confirm it
# matches the registered operator.
add_shape_compute_mapping(
    operator_schema="aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    func=native_batch_norm,
)
# Qualified with ``aten::`` like every other aten registration in this file;
# the unqualified name does not identify an operator namespace.
add_shape_compute_mapping(
    "aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
    _batch_norm_with_update,
)
add_shape_compute_mapping(
    operator_schema="aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
    func=cross_entropy_loss,
)
# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)

# TODO: migrate over all of symbolic_shape_registry_util.cpp
# These are duplicated here so that the functions will be serialized
add_shape_compute_mapping(
    operator_schema="aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor",
    func=broadcast_three,
)
add_shape_compute_mapping(
    operator_schema="aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
    func=broadcast_one_three,
)
add_shape_compute_mapping(
    operator_schema="aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
    func=broadcast_inplace,
)

# quantized_conv_prepack TODO

# Shape compute functions with lower and upper bounds
add_bounded_compute_mapping(
    operator_schema="aten::nonzero(Tensor self) -> (Tensor)",
    lower_bound_func=nonzero_lower_bound,
    upper_bound_func=nonzero_upper_bound,
)
|