ops.py 95 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import math
  4. import operator
  5. from typing import * # noqa: F403
  6. import torch
  7. import torch.nn.functional as F
  8. from torch.fx.operator_schemas import normalize_function
  9. from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
  10. from .nested_tensor import NestedTensor
  11. __all__: list[Any] = []
  12. JAGGED_OPS_TABLE: Dict[Any, Any] = {}
  13. def _get_padding_value(dtype, padding_type):
  14. if dtype.is_floating_point:
  15. return (
  16. torch.finfo(dtype).max if padding_type == "max" else torch.finfo(dtype).min
  17. )
  18. else:
  19. # For integer dtypes, use infinity sentinels which the C++ implementation
  20. # clamps to dtype min/max, avoiding precision loss through double.
  21. return float("inf") if padding_type == "max" else float("-inf")
  22. def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
  23. from torch._prims_common import canonicalize_dims
  24. if isinstance(dim, (tuple, list)):
  25. output = type(dim)(_outer_to_inner_dim(ndim, d, ragged_dim) for d in dim)
  26. # ensure no duplicates, which can result from both batch and ragged mapping to 0
  27. return type(output)(dict.fromkeys(output))
  28. if canonicalize:
  29. dim = canonicalize_dims(ndim, dim)
  30. if not (dim >= 0 and dim < ndim): # pyrefly: ignore [unsupported-operation]
  31. raise AssertionError(f"dim {dim} out of range for ndim {ndim}")
  32. # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
  33. # For other dims, subtract 1 to convert to inner space.
  34. return (
  35. # pyrefly: ignore [unsupported-operation]
  36. ragged_dim - 1 if dim == 0 else dim - 1
  37. )
  38. def _wrap_jagged_dim(
  39. ndim,
  40. dim,
  41. ragged_dim,
  42. op_name,
  43. convert_to_inner_dim=True,
  44. allow_ragged_dim=False,
  45. allow_batch_dim=False,
  46. ):
  47. from torch._prims_common import canonicalize_dims
  48. wrapped = canonicalize_dims(ndim, dim)
  49. if wrapped == ragged_dim and not allow_ragged_dim:
  50. raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim")
  51. elif wrapped == 0 and not allow_batch_dim:
  52. raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
  53. ret = (
  54. _outer_to_inner_dim(ndim, wrapped, ragged_dim)
  55. if convert_to_inner_dim
  56. else wrapped
  57. )
  58. if allow_batch_dim:
  59. # Need to disambiguate whether we're operating on the batch dim or not.
  60. # Operating on dim=1 -> dim=0 after the inner dim conversion.
  61. operating_on_batch = wrapped == 0
  62. return (ret, operating_on_batch)
  63. return ret
  64. def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
  65. """
  66. For NestedTensor operators,
  67. wraps dimensions to non-negative values,
  68. and returns metadata related to reduction dimension(s).
  69. """
  70. from torch._prims_common import canonicalize_dims
  71. if not isinstance(dims, (tuple, list)):
  72. raise AssertionError(
  73. f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"
  74. )
  75. wrapped_dims = [
  76. canonicalize_dims(ndim, d) for d in dims
  77. ] # convert all indices to non-negative values
  78. operate_on_batch = 0 in wrapped_dims
  79. operate_on_ragged = ragged_idx in wrapped_dims
  80. operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims)
  81. # ensure no duplicates, which can result from both batch and ragged mapping to 0
  82. outer_to_inner_dim = tuple(
  83. dict.fromkeys(_outer_to_inner_dim(ndim, d, ragged_idx) for d in wrapped_dims)
  84. )
  85. return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
  86. def check_schema(schema_str: str, func, *args, **kwargs) -> None:
  87. named_arg_types = schema_str.split(", ")
  88. num_optional_args = [x.endswith("?") for x in named_arg_types].count(True)
  89. min_args = len(named_arg_types) - num_optional_args
  90. # special case: ellipses allows for any number of unchecked args at the end
  91. if named_arg_types[-1] == "...":
  92. named_arg_types = named_arg_types[:-1]
  93. else:
  94. if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
  95. raise ValueError(
  96. f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
  97. f"arguments and at most {len(named_arg_types)} arguments, but got: "
  98. f"{len(args)} arguments"
  99. )
  100. arg_type_check_fns = {
  101. "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
  102. "jt": lambda x: isinstance(x, NestedTensor)
  103. and x._lengths is None
  104. and x._ragged_idx == 1, # ops with "jt" require contiguous JT only
  105. "jt_all": lambda x: isinstance(
  106. x, NestedTensor
  107. ), # ops with "jt_all" can accept all kinds of JT
  108. "any": lambda x: True,
  109. }
  110. for i, named_arg_type in enumerate(named_arg_types):
  111. name, arg_type = named_arg_type.split(": ")
  112. is_optional = arg_type.endswith("?")
  113. normalized_arg_type = arg_type[:-1] if is_optional else arg_type
  114. if normalized_arg_type not in arg_type_check_fns:
  115. raise AssertionError(f"Unknown arg type: {normalized_arg_type}")
  116. if i >= len(args):
  117. if not is_optional:
  118. raise ValueError(
  119. f"NestedTensor {func.__name__}({schema_str}) "
  120. f"missing required argument: {name}"
  121. )
  122. continue
  123. _check_fn = arg_type_check_fns[normalized_arg_type]
  124. def check_fn(x, is_optional=is_optional):
  125. if is_optional:
  126. return x is None or _check_fn(x)
  127. else:
  128. return _check_fn(x)
  129. if not check_fn(args[i]):
  130. type_to_desc = {
  131. "t": "tensor",
  132. "t?": "optional tensor",
  133. "jt": "contiguous jagged layout NestedTensor",
  134. "jt_all": "jagged layout NestedTensor",
  135. "any": "<any type>",
  136. }
  137. raise ValueError(
  138. f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
  139. f"{type_to_desc[arg_type]}"
  140. )
  141. def check_ragged_dim_same(
  142. func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
  143. ) -> None:
  144. # Calling into .shape here
  145. if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
  146. raise RuntimeError(
  147. f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
  148. "same exact offsets tensor."
  149. )
  150. # returns True if the raggedness-relevant portions of the NT shape
  151. # match those of the specified size
  152. def raggedness_matches(nt, size):
  153. end = nt._ragged_idx + 1
  154. nt_ragged = nt._size[:end]
  155. size_ragged = size[:end]
  156. return len(nt_ragged) == len(size_ragged) and (
  157. all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
  158. )
  159. def squeeze_leading_ones(t):
  160. # Note: [ Squeezing leading ones ]
  161. #
  162. # Squeeze leading ones from t.
  163. #
  164. # We want:
  165. # (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
  166. # (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) (not yet supported)
  167. #
  168. # 1) Squeeze extra ones and grab values from NT
  169. # (1, 1, ?, ?) -> (?, ?) and (sum(*), ?, ?) -> (B, j0, ?, ?)
  170. # 2) Do dense broadcasting:
  171. # (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
  172. # 3) Construct nested tensor
  173. # (sum(*), ?, ?) -> (B, j0, ?, ?)
  174. #
  175. # If unsqueezing on the 0th dim becomes supported, we would unsqueeze
  176. # at step (4) and we would need to update this function to record how
  177. # many ones we unsqueezed.
  178. while t.dim() > 0 and t.shape[0] == 1:
  179. t = t.squeeze(0)
  180. return t
  181. def register_func(tables, aten_ops, schema_str):
  182. if not isinstance(aten_ops, list):
  183. aten_ops = [aten_ops]
  184. if not isinstance(tables, list):
  185. tables = [tables]
  186. def wrapper(func):
  187. for aten_op in aten_ops:
  188. def get_inner(aten_op):
  189. def inner(*args, **kwargs):
  190. check_schema(schema_str, func, *args, **kwargs)
  191. return func(aten_op, *args, **kwargs)
  192. return inner
  193. for table in tables:
  194. table[aten_op] = get_inner(aten_op)
  195. return func
  196. return wrapper
# Decorator factory for the jagged dispatch table:
# register_jagged_func(aten_ops, schema_str) is register_func with ``tables``
# pre-bound to JAGGED_OPS_TABLE.
register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
  198. def lookup_jagged(func, *args, **kwargs) -> Callable | None:
  199. dispatch_func = JAGGED_OPS_TABLE.get(func, None)
  200. if dispatch_func is not None:
  201. return dispatch_func
  202. # Handle pointwise fallbacks
  203. if torch.Tag.pointwise in func.tags:
  204. from torch.fx.experimental.symbolic_shapes import is_nested_int
  205. # No pointwise ops legitimately accept nested int inputs. Without this check,
  206. # they will be incorrectly interpreted as tensors.
  207. # See https://github.com/pytorch/pytorch/issues/138496
  208. for arg in args:
  209. if is_nested_int(arg):
  210. raise RuntimeError(
  211. f"NestedTensor {func.__name__}: invalid argument {arg}"
  212. )
  213. # Assume there aren't additional tensors that aren't the "unary/binary" args
  214. num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args)
  215. if num_tensor_args == 1:
  216. # Build up the check schema string. The first tensor arg is assumed to be
  217. # an NJT and other args are sent through as-is.
  218. schema_parts = []
  219. for arg in func._schema.arguments:
  220. if isinstance(arg.type, torch.TensorType):
  221. schema_parts.append(f"{arg.name}: jt_all")
  222. break
  223. else:
  224. schema_parts.append(f"{arg.name}: any")
  225. schema_parts.append("...")
  226. check_schema_str = ", ".join(schema_parts)
  227. check_schema(check_schema_str, func, *args, **kwargs)
  228. return functools.partial(jagged_unary_pointwise, func)
  229. elif num_tensor_args == 2:
  230. check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
  231. return functools.partial(jagged_binary_pointwise, func)
  232. return None
  233. def extract_kwargs(arg):
  234. kwargs = {
  235. "offsets": arg.offsets(),
  236. "lengths": arg.lengths(),
  237. "_metadata_cache": arg._metadata_cache,
  238. "_ragged_idx": arg._ragged_idx,
  239. }
  240. return kwargs
  241. def jagged_unary_pointwise(func, *args, **kwargs):
  242. # assume if we get here that there is a single NJT input in the args
  243. njt = next(arg for arg in args if isinstance(arg, NestedTensor))
  244. return NestedTensor(
  245. func(*(arg._values if arg is njt else arg for arg in args), **kwargs),
  246. **extract_kwargs(njt),
  247. )
def jagged_binary_pointwise(func, *args, **kwargs):
    """Apply a binary pointwise aten op where at least one operand is an NJT.

    Cases handled (see the inline examples):
      * NJT op NJT: the raggedness of both must match (``raggedness_matches``).
        The op runs on the two ``_values`` tensors, and the result takes the
        first operand's metadata.
      * NJT op dense, where the dense operand (after squeezing leading size-1
        dims) has at least 2 fewer dims than the NJT: the op runs on the NJT's
        ``_values`` and the squeezed dense tensor.
      * NJT op dense with equal dim counts and equal size at dim 0: the dense
        tensor is expanded along the ragged dim, converted to an NJT with the
        NJT's offsets via ``nested_from_padded``, and ``func`` is re-invoked on
        the two NJTs.

    Operand order (lhs/rhs) is preserved in every case. Positional args after
    the first two, and all kwargs, are forwarded to ``func``.

    Raises:
        AssertionError: neither of the first two args is a NestedTensor.
        RuntimeError: the shapes cannot be combined under the rules above.
        NotImplementedError: the dense operand has more dims than the NJT.
    """
    a, b = args[0], args[1]
    if not (isinstance(a, NestedTensor) or isinstance(b, NestedTensor)):
        raise AssertionError("At least one of the arguments must be a NestedTensor")
    mismatch_error_msg = (
        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
    )
    # a is NT, b is NT
    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
        # ex: (B, j0, D) + (B, j0, D)
        # ex: (B, j0, D) + (B, j0, 1)
        if raggedness_matches(a, b._size):
            # Output reuses a's NJT metadata (offsets / lengths / ragged_idx).
            return NestedTensor(
                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
            )
        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))
    # either a is NT or b is NT at this point
    a_is_nt = isinstance(a, NestedTensor)
    # Output metadata comes from whichever operand is the NJT.
    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)
    # === Handle broadcasting across the batch / ragged dims ===
    # Easy case: take advantage of pre-existing broadcasting logic
    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
    nt, t = (a, b) if a_is_nt else (b, a)
    # See Note: [ Squeezing leading ones ]
    if t.dim() > nt.dim():
        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
    t_squeezed = squeeze_leading_ones(t)
    # The squeezed dense operand has at least 2 fewer dims than the NJT: combine
    # it directly with the NJT's packed values via dense broadcasting.
    if nt.dim() >= t_squeezed.dim() + 2:
        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)
    # Harder case: do manual broadcasting when NT dim == non-NT dim
    # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1)
    if a.dim() == b.dim():
        # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should
        # be (B, j0, D_0, D_1) but not yet supported
        if a.shape[0] != b.shape[0]:
            raise RuntimeError(
                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
            )
        from .nested_tensor import nested_from_padded

        # handle broadcasting via padded dense -> jagged conversion
        min_seqlen = nt._maybe_min_seqlen
        max_seqlen = nt._maybe_max_seqlen
        padded_max_S = max_seqlen
        # Size of the packed dim of the NJT's values (outer ragged dim - 1).
        total_L = nt._values.shape[nt._ragged_idx - 1]
        if padded_max_S is None:
            # use upper bound on max seqlen if it's not present
            padded_max_S = total_L
        # convert dense tensor -> jagged
        # (expand the dense tensor's entry at the ragged index to the padded
        # length; every other dim keeps its current size)
        t = t.expand(
            [x if i != nt._ragged_idx else padded_max_S for i, x in enumerate(t.shape)]
        )
        t_as_nt = nested_from_padded(
            t,
            offsets=nt._offsets,
            ragged_idx=nt._ragged_idx,
            sum_S=total_L,
            min_seqlen=min_seqlen,
            max_seqlen=max_seqlen,
        )
        # function call with two NJTs
        lhs, rhs = (nt, t_as_nt) if a_is_nt else (t_as_nt, nt)
        return func(lhs, rhs, *args[2:], **kwargs)
    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the invariant
    # that ragged dim is wrt left-most batch dim
    # NOTE(review): a dense operand with more dims than the NJT is already rejected
    # above with NotImplementedError. This raise is reached when the dense operand
    # has fewer dims than the NJT but too many (after squeezing) for the easy case.
    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))
  316. def jagged_torch_function(func, *args, **kwargs):
  317. # SDPA has special kernels that handle nested tensors.
  318. # Dispatch to the correct implementation here
  319. if func is torch._C._nn.scaled_dot_product_attention:
  320. return jagged_scaled_dot_product_attention(*args, **kwargs)
  321. if func.__name__ == "apply_":
  322. func(args[0]._values, *args[1:], **kwargs)
  323. return args[0]
  324. # Handle flatten() here because it's CompositeImplicit.
  325. if func.__name__ == "flatten":
  326. def _flatten_sig(input, start_dim=0, end_dim=-1) -> None:
  327. pass
  328. _, new_kwargs = normalize_function( # type: ignore[misc]
  329. _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  330. )
  331. inp = new_kwargs.pop("input")
  332. # NB: stay in outer dim space because we're going to redispatch on a NT input
  333. start_dim = _wrap_jagged_dim(
  334. inp.dim(),
  335. new_kwargs["start_dim"],
  336. inp._ragged_idx,
  337. "flatten",
  338. convert_to_inner_dim=False,
  339. )
  340. end_dim = _wrap_jagged_dim(
  341. inp.dim(),
  342. new_kwargs["end_dim"],
  343. inp._ragged_idx,
  344. "flatten",
  345. convert_to_inner_dim=False,
  346. )
  347. if start_dim == end_dim:
  348. return inp
  349. product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
  350. new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])
  351. return inp.reshape(*new_shape)
  352. # Handle NestedTensor share_memory_.
  353. if func.__name__ == "share_memory_":
  354. nt = args[0]
  355. if nt.is_cuda:
  356. return nt
  357. names, _ = nt.__tensor_flatten__()
  358. with torch._C.DisableTorchFunctionSubclass():
  359. for name in names:
  360. component = getattr(nt, name, None)
  361. if component is not None:
  362. component.share_memory_()
  363. return nt
  364. # Handle NestedTensor is_shared.
  365. if func.__name__ == "is_shared":
  366. nt = args[0]
  367. if nt.is_cuda:
  368. return False
  369. names, _ = nt.__tensor_flatten__()
  370. if not names:
  371. return False
  372. return all(
  373. getattr(nt, name) is not None and getattr(nt, name).is_shared()
  374. for name in names
  375. )
  376. # Handle nested-specific input validation for CompositeImplicit rms_norm
  377. if func.__name__ == "rms_norm":
  378. def _rms_norm_sig(input, normalized_shape, weight=None, eps=None) -> None:
  379. pass
  380. _, new_kwargs = normalize_function( # type: ignore[misc]
  381. _rms_norm_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  382. )
  383. inp = new_kwargs.pop("input")
  384. normalized_shape = new_kwargs.pop("normalized_shape")
  385. # can't normalize over the ragged dim (yet)
  386. max_normalizable = inp.dim() - inp._ragged_idx - 1
  387. if len(normalized_shape) > max_normalizable:
  388. raise ValueError(
  389. "rms_norm(): Normalization over the ragged dim not supported for nested tensors"
  390. )
  391. with torch._C.DisableTorchFunctionSubclass():
  392. return func(*args, **kwargs)
  393. raise NotImplementedError(func)
  394. @register_jagged_func(
  395. [
  396. torch.ops.aten.is_non_overlapping_and_dense.default,
  397. torch.ops.aten.sym_size.default,
  398. torch.ops.aten.dim.default,
  399. torch.ops.aten.numel.default,
  400. torch.ops.aten.sym_numel.default,
  401. torch.ops.aten.sym_stride.default,
  402. torch.ops.aten.sym_storage_offset.default,
  403. ],
  404. "self: jt_all",
  405. )
  406. def tensor_attr_supported_getter(func, *args, **kwargs):
  407. if func is torch.ops.aten.is_non_overlapping_and_dense.default:
  408. return False
  409. if func is torch.ops.aten.sym_size.default:
  410. return args[0]._size
  411. if func is torch.ops.aten.dim.default:
  412. return len(args[0]._size)
  413. if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
  414. if args[0]._lengths is not None:
  415. return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
  416. return args[0]._values.numel()
  417. if func is torch.ops.aten.sym_stride.default:
  418. return args[0]._strides
  419. if func is torch.ops.aten.sym_storage_offset.default:
  420. return args[0]._values.storage_offset()
  421. @register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
  422. def prim_layout_default(func, *args, **kwargs):
  423. return torch.jagged
  424. @register_jagged_func(
  425. [torch.ops.aten.size.default],
  426. "self: jt_all",
  427. )
  428. def tensor_attr_unsupported_getter(func, *args, **kwargs) -> None:
  429. if func is torch.ops.aten.size.default:
  430. raise RuntimeError(
  431. "NestedTensor does not support directly calling torch.ops.aten.size; "
  432. "please use `nested_tensor.size()` instead."
  433. )
  434. @register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
  435. def is_contiguous_general(func, *args, **kwargs):
  436. from torch._prims_common import is_contiguous_for_memory_format
  437. _, new_kwargs = normalize_function( # type: ignore[misc]
  438. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  439. )
  440. inp = new_kwargs.pop("input")
  441. # If created from narrow() check for lengths
  442. if inp.lengths() is not None:
  443. return False
  444. new_kwargs["memory_format"] = new_kwargs.get(
  445. "memory_format", torch.contiguous_format
  446. )
  447. if new_kwargs["memory_format"] == torch.preserve_format:
  448. return True
  449. return is_contiguous_for_memory_format(inp._values, **new_kwargs)
  450. register_jagged_func(
  451. torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
  452. )(is_contiguous_general)
  453. @register_jagged_func(
  454. torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
  455. )
  456. def sym_is_contiguous_general(func, *args, **kwargs):
  457. _, new_kwargs = normalize_function( # type: ignore[misc]
  458. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  459. )
  460. inp = new_kwargs.pop("input")
  461. # If created from narrow() check for lengths
  462. if inp.lengths() is not None:
  463. return False
  464. new_kwargs["memory_format"] = new_kwargs.get(
  465. "memory_format", torch.contiguous_format
  466. )
  467. if new_kwargs["memory_format"] == torch.preserve_format:
  468. return True
  469. return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)
  470. @register_jagged_func(
  471. torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
  472. )
  473. def clone_default(func, *args, **kwargs):
  474. _, new_kwargs = normalize_function( # type: ignore[misc]
  475. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  476. )
  477. inp = new_kwargs.pop("input")
  478. new_meta = extract_kwargs(inp)
  479. if inp._lengths is not None:
  480. if new_kwargs["memory_format"] == torch.contiguous_format:
  481. # need to copy to remove "holes" non-contiguity / lengths metadata
  482. # TODO: write a kernel for this
  483. from .nested_tensor import jagged_from_list
  484. # TODO: We probably want the output to have the same ragged structure / nested int.
  485. if inp._ragged_idx != 1:
  486. raise AssertionError(
  487. "NJT with ragged_idx != 1 not supported for contiguous clone"
  488. )
  489. contig, _ = jagged_from_list(inp.unbind(), offsets=None)
  490. return contig
  491. return NestedTensor(func(inp._values, **new_kwargs), **new_meta)
  492. @register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
  493. def linear_default(func, *args, **kwargs):
  494. _, new_kwargs = normalize_function( # type: ignore[misc]
  495. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  496. )
  497. inp = new_kwargs.pop("input")
  498. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  499. @register_jagged_func(
  500. torch.ops.aten.linear_backward.default,
  501. "self: jt, grad_output: jt, weight: t, output_mask: any",
  502. )
  503. def linear_backward_default(func, *args, **kwargs):
  504. _, new_kwargs = normalize_function( # type: ignore[misc]
  505. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  506. )
  507. inp = new_kwargs.pop("input")
  508. grad_output = new_kwargs.pop("grad_output")
  509. weight = new_kwargs.pop("weight")
  510. output_mask = new_kwargs.pop("output_mask")
  511. ds, dw, db = None, None, None
  512. check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
  513. if output_mask[0]:
  514. ds = NestedTensor(
  515. torch.matmul(grad_output._values, weight), **extract_kwargs(grad_output)
  516. )
  517. if output_mask[1]:
  518. # NB: Fold dims of values for input and grad_output to treat them as 2D. This
  519. # trick avoids materializing large intermediates and immediately reducing over
  520. # them via sum(). This is equivalent to computing:
  521. # torch.matmul(grad_output._values.transpose(-2, -1), inp._values)
  522. # and then summing over the leading dimensions to get a 2D weight grad.
  523. grad_2d = grad_output._values.reshape(-1, weight.size(0))
  524. input_2d = inp._values.reshape(-1, weight.size(1))
  525. dw = torch.matmul(grad_2d.t(), input_2d)
  526. if output_mask[2]:
  527. # Sum over all but the last dim to get a 1D bias grad. We cannot
  528. # rely on the autograd engine to reduce for us, because returning a
  529. # tensor aliasing the input would violate the aten signature annotation
  530. reduce_dims = tuple(range(grad_output._values.ndim - 1))
  531. if reduce_dims == ():
  532. db = grad_output._values.clone()
  533. else:
  534. db = torch.sum(grad_output._values, reduce_dims, keepdim=False)
  535. return (ds, dw, db)
  536. @register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any")
  537. def to_dtype(func, *args, **kwargs):
  538. _, new_kwargs = normalize_function( # type: ignore[misc]
  539. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  540. )
  541. inp = new_kwargs.pop("input")
  542. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  543. @register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
  544. def to_copy_default(func, *args, **kwargs):
  545. from .nested_tensor import _tensor_symint_registry
  546. _, new_kwargs = normalize_function( # type: ignore[misc]
  547. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  548. )
  549. inp = new_kwargs.pop("input")
  550. # don't change layout
  551. new_kwargs.pop("layout")
  552. new_values = func(inp._values, **new_kwargs)
  553. new_offsets = inp._offsets.to(device=new_values.device)
  554. new_lengths = None
  555. if inp._lengths is not None:
  556. new_lengths = inp._lengths.to(device=new_values.device)
  557. from torch._subclasses.fake_tensor import FakeTensor
  558. from torch._subclasses.functional_tensor import (
  559. FunctionalTensor,
  560. mb_unwrap_functional_tensor,
  561. )
  562. ragged_source = inp._offsets if inp._lengths is None else inp._lengths
  563. new_thing = new_offsets if new_lengths is None else new_lengths
  564. if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
  565. # Temporary hack until we have the union find
  566. tgt = mb_unwrap_functional_tensor(new_thing)
  567. src = mb_unwrap_functional_tensor(ragged_source)
  568. # pyrefly: ignore[missing-attribute]
  569. tgt.nested_int_memo = src.nested_int_memo
  570. else:
  571. _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
  572. inp_kwargs = extract_kwargs(inp)
  573. inp_kwargs["offsets"] = new_offsets
  574. inp_kwargs["lengths"] = new_lengths
  575. output = NestedTensor(new_values, **inp_kwargs)
  576. return output
@register_jagged_func(
    torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?"
)
def copy_default(func, *args, **kwargs):
    """In-place ``copy_`` from one NJT into another; returns the mutated ``self``.

    When the two NJTs report different sizes (e.g. distinct nested ints), the copy
    is first performed component-by-component over ``unbind()``, after checking
    that every unbound component shape matches. In all cases the values buffer is
    then copied directly (see the AOTD note below). ``non_blocking`` is accepted
    by the schema but is not forwarded here.

    Raises:
        RuntimeError: if the sizes differ and the unbound component shapes do not
            match.
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    src = new_kwargs.pop("src")
    if inp._size != src._size:
        # try to recursively copy_ on unbound components to get around nested int mismatch
        # TODO: eventually do a direct copy when this is possible
        inp_comps = inp.unbind()
        inp_comp_shapes = [c.shape for c in inp_comps]
        src_comps = src.unbind()
        src_comp_shapes = [c.shape for c in src_comps]
        if inp_comp_shapes != src_comp_shapes:
            raise RuntimeError(
                "copy_(): expected compatible input and src shapes, but got: "
                f"{inp.shape} and {src.shape}"
            )
        for inp_comp, src_comp in zip(inp_comps, src_comps):
            inp_comp.copy_(src_comp)
    # AOTD allows mutations of inputs only, (not views of the inputs).
    # NJT.values() returns _values.detach() to workaround some issues.
    # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT).
    # Here we directly mutate self._values to not emit .detach() in the graph, which would make it non-compilable.
    # NOTE(review): this direct copy also runs after the per-component path above;
    # if the two values buffers have different shapes (e.g. one NJT has holes),
    # this copy_ may fail to broadcast -- confirm that is intended.
    inp._values.copy_(src._values)
    return inp
  606. register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
  607. jagged_unary_pointwise
  608. )
  609. @register_jagged_func(
  610. [
  611. torch.ops.aten.empty_like.default,
  612. torch.ops.aten.ones_like.default,
  613. torch.ops.aten.zeros_like.default,
  614. torch.ops.aten.rand_like.default,
  615. torch.ops.aten.randn_like.default,
  616. ],
  617. "self: jt_all",
  618. )
  619. def like_factory_default(func, *args, **kwargs):
  620. _, new_kwargs = normalize_function( # type: ignore[misc]
  621. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  622. )
  623. inp = new_kwargs.pop("input")
  624. # Default layout is technically torch.strided but only jagged is supported here.
  625. # Rather than force users to specify the layout, assume jagged.
  626. # This should be set to strided for redispatching on values.
  627. new_kwargs["layout"] = torch.strided
  628. new_values = func(inp._values, **new_kwargs)
  629. new_offsets = inp._offsets.to(device=new_values.device)
  630. new_lengths = None
  631. if inp._lengths is not None:
  632. new_lengths = inp._lengths.to(device=new_values.device)
  633. output_kwargs = extract_kwargs(inp)
  634. if "offsets" in output_kwargs:
  635. output_kwargs["offsets"] = new_offsets
  636. if "lengths" in output_kwargs:
  637. output_kwargs["lengths"] = new_lengths
  638. if inp.device != new_values.device:
  639. # Update the nested int registry to indicate that the ragged structure is the same
  640. # between the two offsets / lengths on different devices.
  641. from torch._subclasses.fake_tensor import FakeTensor
  642. from torch._subclasses.functional_tensor import (
  643. FunctionalTensor,
  644. mb_unwrap_functional_tensor,
  645. )
  646. from .nested_tensor import _tensor_symint_registry
  647. ragged_source = inp._offsets if inp._lengths is None else inp._lengths
  648. new_thing = new_offsets if new_lengths is None else new_lengths
  649. if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
  650. # Temporary hack until we have the union find
  651. tgt = mb_unwrap_functional_tensor(new_thing)
  652. src = mb_unwrap_functional_tensor(ragged_source)
  653. # pyrefly: ignore[missing-attribute]
  654. tgt.nested_int_memo = src.nested_int_memo
  655. else:
  656. _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
  657. return NestedTensor(new_values, **output_kwargs)
  658. register_jagged_func(torch.ops.aten.full_like.default, "self: jt_all, fill_value: any")(
  659. like_factory_default
  660. )
  661. register_jagged_func(torch.ops.aten.randint_like.default, "self: jt_all, high: any")(
  662. like_factory_default
  663. )
  664. register_jagged_func(
  665. torch.ops.aten.randint_like.low_dtype, "self: jt_all, low: any, high: any"
  666. )(like_factory_default)
  667. @register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
  668. def zero__default(func, *args, **kwargs):
  669. _, new_kwargs = normalize_function( # type: ignore[misc]
  670. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  671. )
  672. inp = new_kwargs.pop("input")
  673. func(inp._values)
  674. return inp
  675. @register_jagged_func(
  676. torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any"
  677. )
  678. def _softmax_default(func, *args, **kwargs):
  679. _, new_kwargs = normalize_function( # type: ignore[misc]
  680. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  681. )
  682. if isinstance(new_kwargs["dim"], tuple):
  683. raise RuntimeError(
  684. "softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
  685. )
  686. inp = new_kwargs.pop("input")
  687. (
  688. new_kwargs["dim"],
  689. reduce_on_batch,
  690. reduce_on_ragged,
  691. _reduce_on_non_batch,
  692. ) = _wrap_jagged_dims(
  693. inp.dim(),
  694. (new_kwargs["dim"],),
  695. "softmax",
  696. inp._ragged_idx,
  697. )
  698. if reduce_on_batch:
  699. raise RuntimeError(
  700. "softmax(): not supported when reducing across the batch dimension for NestedTensor"
  701. )
  702. if reduce_on_ragged and inp._ragged_idx > 1:
  703. raise RuntimeError(
  704. "softmax(): not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor"
  705. )
  706. if reduce_on_ragged and inp._lengths is not None:
  707. raise RuntimeError(
  708. "softmax(): not supported where lengths is not None "
  709. + "if reducing across the ragged dimension for NestedTensor"
  710. )
  711. new_kwargs["dim"] = new_kwargs["dim"][
  712. 0
  713. ] # torch.softmax takes in the reduction dimension as an integer
  714. if reduce_on_ragged:
  715. padded_softmax_values = torch.nn.functional.softmax(
  716. torch.ops.aten._jagged_to_padded_dense_forward(
  717. inp._values.reshape(
  718. inp._values.shape[0], -1
  719. ), # values are required to be 2D tensors for j2pd
  720. [inp._offsets],
  721. max_lengths=[inp._max_seqlen], # max length of ragged dimension
  722. padding_value=float("-inf"), # e^-inf = 0
  723. ),
  724. dim=inp._ragged_idx,
  725. )
  726. softmax_values = torch.ops.aten._padded_dense_to_jagged_forward(
  727. padded_softmax_values,
  728. [inp._offsets],
  729. total_L=inp._values.shape[
  730. 0
  731. ], # providing this parameter helps avoid a GPU/CPU sync
  732. ).reshape(
  733. -1, *inp._values.shape[1:]
  734. ) # expand softmax_values back to original shape (inp._values.shape)
  735. return NestedTensor(softmax_values, **extract_kwargs(inp))
  736. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  737. @register_jagged_func(
  738. torch.ops.aten._log_softmax.default, "self: jt_all, dim: any, half_to_float: any"
  739. )
  740. def _log_softmax_default(func, *args, **kwargs):
  741. _, new_kwargs = normalize_function( # type: ignore[misc]
  742. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  743. )
  744. if isinstance(new_kwargs["dim"], tuple):
  745. raise RuntimeError(
  746. "log_softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
  747. )
  748. inp = new_kwargs.pop("input")
  749. (
  750. new_kwargs["dim"],
  751. reduce_on_batch,
  752. reduce_on_ragged,
  753. _reduce_on_non_batch,
  754. ) = _wrap_jagged_dims(
  755. inp.dim(), (new_kwargs["dim"],), "log_softmax", inp._ragged_idx
  756. )
  757. if reduce_on_batch:
  758. raise RuntimeError(
  759. "log_softmax(): not supported when reducing across the batch dimension for NestedTensor"
  760. )
  761. if reduce_on_ragged:
  762. raise RuntimeError(
  763. "log_softmax(): not supported when reducing along the ragged dimension for NestedTensor"
  764. )
  765. # torch.log_softmax takes in the reduction dimension as an integer
  766. new_kwargs["dim"] = new_kwargs["dim"][0]
  767. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  768. @register_jagged_func(
  769. torch.ops.aten._softmax_backward_data.default,
  770. "grad_output: jt, output: jt, dim: any, input_dtype: any",
  771. )
  772. def _softmax_backward(func, *args, **kwargs):
  773. _, new_kwargs = normalize_function( # type: ignore[misc]
  774. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  775. )
  776. grad_out = new_kwargs.pop("grad_output")
  777. output = new_kwargs.pop("output")
  778. return NestedTensor(
  779. func(grad_out._values, output._values, **new_kwargs), **extract_kwargs(grad_out)
  780. )
  781. @register_jagged_func(
  782. torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
  783. )
  784. def native_dropout_default(func, *args, **kwargs):
  785. _, new_kwargs = normalize_function( # type: ignore[misc]
  786. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  787. )
  788. inp = new_kwargs.pop("input")
  789. out1, out2 = func(inp._values, **new_kwargs)
  790. return (
  791. NestedTensor(out1, **extract_kwargs(inp)),
  792. NestedTensor(out2, **extract_kwargs(inp)),
  793. )
  794. @register_jagged_func(
  795. torch.ops.aten.native_dropout_backward.default,
  796. "grad_output: jt, mask: jt, scale: any",
  797. )
  798. def native_dropout_backward_default(func, *args, **kwargs):
  799. _, new_kwargs = normalize_function( # type: ignore[misc]
  800. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  801. )
  802. grad_output = new_kwargs.pop("grad_output")
  803. mask = new_kwargs.pop("mask")
  804. return NestedTensor(
  805. func(grad_output._values, mask._values, **new_kwargs),
  806. **extract_kwargs(grad_output),
  807. )
  808. @register_jagged_func(
  809. torch.ops.aten.prod.dim_int,
  810. "self: jt_all, dim: any, keepdim: any?, dtype: any?",
  811. )
  812. def prod_dim_int(func, *args, **kwargs):
  813. return _apply_reduction(func, "prod", 1, *args, **kwargs)
  814. @register_jagged_func(torch.ops.aten.prod.default, "self: jt_all, dtype: any?")
  815. def prod_default(func, *args, **kwargs):
  816. _, new_kwargs = normalize_function( # type: ignore[misc]
  817. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  818. )
  819. inp = new_kwargs.pop("input")
  820. return func(inp._values, **new_kwargs)
  821. @register_jagged_func(
  822. torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any?"
  823. )
  824. def split_tensor(func, *args, **kwargs):
  825. _, new_kwargs = normalize_function( # type: ignore[misc]
  826. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  827. )
  828. inp = new_kwargs.pop("input")
  829. new_kwargs["dim"] = _wrap_jagged_dim(
  830. inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split"
  831. )
  832. return tuple(
  833. NestedTensor(values=x, **extract_kwargs(inp))
  834. for x in func(inp._values, **new_kwargs)
  835. )
  836. @register_jagged_func(
  837. torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any?"
  838. )
  839. def split_with_sizes_default(func, *args, **kwargs):
  840. _, new_kwargs = normalize_function( # type: ignore[misc]
  841. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  842. )
  843. inp = new_kwargs.pop("input")
  844. new_kwargs["dim"] = _wrap_jagged_dim(
  845. inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split_with_sizes"
  846. )
  847. return [
  848. NestedTensor(values=x, **extract_kwargs(inp))
  849. for x in func(inp._values, **new_kwargs)
  850. ]
  851. @register_jagged_func(
  852. torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
  853. )
  854. def narrow(func, *args, **kwargs):
  855. _, new_kwargs = normalize_function( # type: ignore[misc]
  856. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  857. )
  858. inp = new_kwargs.pop("input")
  859. dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow")
  860. values = func(
  861. inp._values,
  862. dim=dim,
  863. start=new_kwargs["start"],
  864. length=new_kwargs["length"],
  865. )
  866. return NestedTensor(values, **extract_kwargs(inp))
  867. @register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
  868. def chunk_default(func, *args, **kwargs):
  869. _, new_kwargs = normalize_function( # type: ignore[misc]
  870. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  871. )
  872. inp = new_kwargs.pop("input")
  873. new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
  874. inp.dim(), new_kwargs["dim"], inp._ragged_idx, "chunk", allow_batch_dim=True
  875. )
  876. if operating_on_batch:
  877. chunks = new_kwargs["chunks"]
  878. # get _offsets of the chunks
  879. lengths = inp._offsets.diff()
  880. chunked_lengths = lengths.chunk(chunks)
  881. chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths]
  882. chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets] # type: ignore[arg-type]
  883. nested_kwargs = [
  884. {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx}
  885. for per_offsets in chunked_offsets
  886. ]
  887. # get _values of the chunks
  888. split_sizes = [x.sum().item() for x in chunked_lengths]
  889. chunk_values = inp._values.split(split_sizes)
  890. # Note that the actual number of chunks returned is not necessarily the same as
  891. # the input number; it can be counter-intuitive, but it matches dense behavior.
  892. return [
  893. NestedTensor(values=chunk_values[i], **(nested_kwargs[i]))
  894. for i in range(len(chunk_values))
  895. ]
  896. else:
  897. return [
  898. NestedTensor(values=x, **extract_kwargs(inp))
  899. for x in func(inp._values, **new_kwargs)
  900. ]
@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
def unbind_int(func, *args, **kwargs):
    """Unbind an NJT along the batch dim into its dense components.

    Only ``dim=0`` is supported. Without ``lengths``, the values buffer is split
    along the ragged dim using the per-component sizes from ``offsets.diff()``.
    With ``lengths``, each component is a ``narrow`` of the values buffer starting
    at ``offsets[i]`` with length ``lengths[i]``.

    Raises:
        RuntimeError: if ``dim != 0``, or (lengths path only) if ``_ragged_idx``
            is below 1.
    """
    # Note that this specializes on the length of the offsets
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    dim = new_kwargs["dim"]
    if dim != 0:
        raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")
    inp = new_kwargs.pop("input")
    values = inp.values()
    offsets = inp.offsets()
    lengths = inp.lengths()
    ragged_idx = inp._ragged_idx

    def _torch_check(_lengths: list[int], _offsets: list[int] | None = None) -> None:
        # These torch._check calls are needed for torch.compile symbolic-shape
        # processing: offsets and lengths are symbolic during compilation, so
        # their correspondence with the values buffer is asserted explicitly:
        #   - every length is in [0, ragged_dim_size]
        #   - sum of lengths <= ragged_dim_size
        #   - offsets[i] + lengths[i] <= ragged_dim_size (unbind / split correctness)
        #   - every offset is in [0, ragged_dim_size]
        lengths_sum = 0
        # size of the ragged dim inside the values buffer (batch dim collapsed)
        ragged_dim_size = values.shape[ragged_idx - 1]
        for i in range(len(_lengths)):
            torch._check(_lengths[i] >= 0)
            torch._check(_lengths[i] <= ragged_dim_size)
            lengths_sum += _lengths[i]
            if _offsets is not None:
                torch._check(
                    _offsets[i] + _lengths[i] <= ragged_dim_size,
                    lambda: "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension",
                )
        torch._check(lengths_sum <= ragged_dim_size)
        if _offsets is not None:
            for i in range(len(_offsets)):
                torch._check(_offsets[i] >= 0)
                torch._check(_offsets[i] <= ragged_dim_size)

    if lengths is None:
        # Contiguous case: component sizes are the gaps between consecutive offsets.
        lengths_scalars = offsets.diff().tolist()
        _torch_check(lengths_scalars)
        return torch.split(values, lengths_scalars, dim=(ragged_idx - 1))
    if ragged_idx <= 0:
        raise RuntimeError(
            "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)"
        )
    # Holes case: each component starts at offsets[i] and spans lengths[i].
    lengths_scalars = lengths.tolist()
    offsets_scalars = offsets.tolist()
    _torch_check(lengths_scalars, offsets_scalars)
    return [
        torch.narrow(
            values,
            dim=(ragged_idx - 1),
            start=offsets_scalars[i],
            length=lengths_scalars[i],
        )
        for i in range(lengths.shape[0])
    ]
  960. @register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
  961. def squeeze_dim(func, *args, **kwargs):
  962. _, new_kwargs = normalize_function( # type: ignore[misc]
  963. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  964. )
  965. inp = new_kwargs.pop("input")
  966. values = inp._values
  967. new_kwargs["dim"] = _wrap_jagged_dim(
  968. len(inp._size), new_kwargs["dim"], inp._ragged_idx, "squeeze"
  969. )
  970. return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))
  971. @register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt_all, dim: any")
  972. def unsqueeze_default(func, *args, **kwargs):
  973. _, new_kwargs = normalize_function( # type: ignore[misc]
  974. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  975. )
  976. inp = new_kwargs.pop("input")
  977. values = inp._values
  978. # Account for collapsed jagged dim
  979. dim = new_kwargs["dim"]
  980. new_kwargs["dim"] = _wrap_jagged_dim(
  981. len(inp._size) + 1, dim, inp._ragged_idx, "unsqueeze", allow_ragged_dim=True
  982. )
  983. # ragged_idx changes if a dimension is added before it
  984. output_kwargs = extract_kwargs(inp)
  985. if new_kwargs["dim"] <= inp._ragged_idx - 1:
  986. output_kwargs["_ragged_idx"] += 1
  987. return NestedTensor(func(values, **new_kwargs), **output_kwargs)
  988. @register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any?")
  989. def cat_default(func, *args, **kwargs):
  990. _, new_kwargs = normalize_function( # type: ignore[misc]
  991. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  992. )
  993. tensors = new_kwargs.pop("tensors")
  994. # Convert any non-nested to nested
  995. nested = [t for t in tensors if t.is_nested]
  996. if len(nested) == 0:
  997. raise AssertionError("At least one tensor must be nested")
  998. first = nested[0]
  999. tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]
  1000. # Account for collapsed jagged dim
  1001. dim = new_kwargs["dim"]
  1002. new_kwargs["dim"] = _wrap_jagged_dim(
  1003. len(first.shape), dim, first._ragged_idx, "cat"
  1004. )
  1005. return NestedTensor(
  1006. func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
  1007. )
  1008. @register_jagged_func(torch.ops.aten.matmul.default, "self: any, other: any")
  1009. def matmul_default(func, *args, **kwargs):
  1010. _, new_kwargs = normalize_function( # type: ignore[misc]
  1011. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1012. )
  1013. inp = new_kwargs.pop("input")
  1014. other = new_kwargs.pop("other")
  1015. def _unbind_impl(a, b):
  1016. return [
  1017. func(a_comp, b_comp) for (a_comp, b_comp) in zip(a.unbind(), b.unbind())
  1018. ]
  1019. def _padded_impl(a, b):
  1020. if a.is_nested:
  1021. nt = a
  1022. else:
  1023. nt = b
  1024. from .nested_tensor import nested_from_padded
  1025. min_seqlen = nt._maybe_min_seqlen
  1026. max_seqlen = nt._maybe_max_seqlen
  1027. padded_max_S = max_seqlen
  1028. total_L = nt._values.shape[nt._ragged_idx - 1]
  1029. if padded_max_S is None:
  1030. # use upper bound on max seqlen if it's not present
  1031. padded_max_S = total_L
  1032. padded_shape = (
  1033. *nt.shape[: nt._ragged_idx],
  1034. padded_max_S,
  1035. *nt.shape[nt._ragged_idx + 1 :],
  1036. )
  1037. padded_nt = nt.to_padded_tensor(0.0, output_size=padded_shape)
  1038. if a.is_nested:
  1039. padded_t = func(padded_nt, b)
  1040. else:
  1041. padded_t = func(a, padded_nt)
  1042. return nested_from_padded(
  1043. padded_t,
  1044. offsets=nt._offsets,
  1045. ragged_idx=nt._ragged_idx,
  1046. sum_S=total_L,
  1047. min_seqlen=min_seqlen,
  1048. max_seqlen=max_seqlen,
  1049. )
  1050. # TODO: Back these with proper kernels (e.g. grouped GEMM)
  1051. # NJT x dense
  1052. if inp.is_nested and not other.is_nested:
  1053. # (B, j1, D) x (B, D, E) => (B, j1, E)
  1054. if (
  1055. inp.dim() >= 3
  1056. and inp.dim() == other.dim()
  1057. and inp._ragged_idx < inp.dim() - 1
  1058. ):
  1059. # convert to padded for this
  1060. return _padded_impl(inp, other)
  1061. # Support broadcasting the dense:
  1062. # (B, j1, D) x (D, E) => (B, j1, E)
  1063. # (B, j1, D, E) x (E, F) => (B, j1, D, F)
  1064. # etc.
  1065. elif (
  1066. other.dim() == 2
  1067. and inp.dim() > other.dim()
  1068. and inp._ragged_idx < inp.dim() - 1
  1069. ):
  1070. return NestedTensor(
  1071. func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
  1072. )
  1073. # Dense x NJT
  1074. elif not inp.is_nested and other.is_nested:
  1075. # (B, D, E) x (B, E, j1) => (B, E, j1)
  1076. if other.dim() >= 3 and other.dim() == inp.dim() and other._ragged_idx >= 2:
  1077. # convert to padded for this
  1078. return _padded_impl(inp, other)
  1079. # Support broadcasting the dense:
  1080. # (D, E) x (B, E, j1) => (B, D, j1)
  1081. # (D, E) x (B, E, j1, F) => (B, D, j1, F)
  1082. # etc.
  1083. elif inp.dim() == 2 and other.dim() > inp.dim() and other._ragged_idx >= 2:
  1084. return NestedTensor(
  1085. func(inp, other._values, **new_kwargs), **extract_kwargs(other)
  1086. )
  1087. # NJT x NJT
  1088. elif inp.is_nested and other.is_nested:
  1089. # Support ragged batch dim:
  1090. # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F), etc.
  1091. if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
  1092. return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))
  1093. # Support reducing over ragged with dense output:
  1094. # (B, D, j1) x (B, j1, E) => (B, D, E)
  1095. elif (
  1096. inp.dim() == 3
  1097. and other.dim() == 3
  1098. and inp._ragged_idx == 2
  1099. and other._ragged_idx == 1
  1100. and inp.size(inp._ragged_idx) == other.size(other._ragged_idx)
  1101. ):
  1102. # do unbind for this; can't use padded conversion due to j1 in last dim
  1103. return torch.stack(_unbind_impl(inp, other))
  1104. raise RuntimeError(
  1105. f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}"
  1106. )
  1107. @register_jagged_func(torch.ops.aten.bmm.default, "self: jt_all, mat2: any")
  1108. def bmm_default(func, *args, **kwargs):
  1109. _, new_kwargs = normalize_function( # type: ignore[misc]
  1110. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1111. )
  1112. inp = new_kwargs.pop("input")
  1113. other = new_kwargs.pop("mat2")
  1114. if inp.dim() != 3:
  1115. raise ValueError("bmm(): input must be 3D")
  1116. if other.dim() != 3:
  1117. raise ValueError("bmm(): mat2 must be 3D")
  1118. return matmul_default(torch.ops.aten.matmul.default, inp, other)
  1119. @register_jagged_func(
  1120. torch.ops.aten.expand.default, "self: jt_all, size: any, implicit: any?"
  1121. )
  1122. def expand_default(func, *args, **kwargs):
  1123. _, new_kwargs = normalize_function( # type: ignore[misc]
  1124. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1125. )
  1126. inp = new_kwargs.pop("input")
  1127. size = new_kwargs["size"]
  1128. if "implicit" in new_kwargs and new_kwargs.pop("implicit"):
  1129. raise AssertionError("implicit expand is not supported")
  1130. if not raggedness_matches(inp, size):
  1131. raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}")
  1132. expand_arg = [-1 if d == inp._ragged_idx else size[d] for d in range(1, inp.dim())]
  1133. return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))
  1134. @register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
  1135. def expand_as_default(func, *args, **kwargs):
  1136. _, new_kwargs = normalize_function( # type: ignore[misc]
  1137. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1138. )
  1139. inp = new_kwargs.pop("input")
  1140. other = new_kwargs.pop("other")
  1141. return NestedTensor(func(inp, other._values), **extract_kwargs(other))
  1142. @register_jagged_func(torch.ops.aten.broadcast_to.default, "self: jt_all, size: any")
  1143. def broadcast_to(func, *args, **kwargs):
  1144. _, new_kwargs = normalize_function( # type: ignore[misc]
  1145. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1146. )
  1147. inp = new_kwargs.pop("input")
  1148. size = new_kwargs.pop("size")
  1149. if len(size) <= inp.dim():
  1150. return inp.expand([*(1 for _ in range(inp.dim() - len(size))), *size])
  1151. raise ValueError(
  1152. "broadcast_to(): broadcasting to a higher-dim shape is currently not supported "
  1153. "for nested tensors with the jagged layout"
  1154. )
  1155. @register_jagged_func(torch.ops.aten.broadcast_tensors.default, "tensors: any")
  1156. def broadcast_tensors(func, *args, **kwargs):
  1157. _, new_kwargs = normalize_function( # type: ignore[misc]
  1158. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1159. )
  1160. tensors = new_kwargs.pop("tensors")
  1161. if len(tensors) == 0:
  1162. raise ValueError("broadcast_tensors(): expected at least one tensor input")
  1163. if len(tensors) == 1:
  1164. return tensors[0]
  1165. outs = []
  1166. broadcast_shape = torch.broadcast_shapes(*(t.shape for t in tensors))
  1167. # Pull out the first NJT. If broadcast_shapes() worked, the nested ints are compatible.
  1168. njt = next(t for t in tensors if isinstance(t, NestedTensor))
  1169. for t in tensors:
  1170. if t.is_nested:
  1171. outs.append(t.broadcast_to(broadcast_shape))
  1172. elif t.dim() < len(broadcast_shape):
  1173. outs.append(
  1174. NestedTensor(t.broadcast_to(njt._values.shape), **extract_kwargs(njt))
  1175. )
  1176. else:
  1177. raise ValueError(
  1178. "broadcast_tensors(): broadcasting nested tensors with dense tensors of equal "
  1179. "or higher dim is not currently supported"
  1180. )
  1181. return tuple(outs)
  1182. @register_jagged_func(
  1183. torch.ops.aten.where.self, "condition: jt_all, self: any, other: any"
  1184. )
  1185. def where_self(func, *args, **kwargs):
  1186. _, new_kwargs = normalize_function( # type: ignore[misc]
  1187. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1188. )
  1189. condition = new_kwargs.pop("condition")
  1190. inp = new_kwargs.pop("input")
  1191. other = new_kwargs.pop("other")
  1192. # if the tensors aren't compatible, broadcast_tensors() will let us know
  1193. condition, inp, other = torch.broadcast_tensors(condition, inp, other)
  1194. return NestedTensor(
  1195. func(condition._values, inp._values, other._values, **new_kwargs),
  1196. **extract_kwargs(condition),
  1197. )
  1198. @register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
  1199. def _pin_memory_default(func, *args, **kwargs):
  1200. _, new_kwargs = normalize_function( # type: ignore[misc]
  1201. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1202. )
  1203. inp = new_kwargs.pop("input")
  1204. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  1205. @register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
  1206. def is_pinned_default(func, *args, **kwargs):
  1207. _, new_kwargs = normalize_function( # type: ignore[misc]
  1208. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1209. )
  1210. inp = new_kwargs.pop("input")
  1211. return func(inp._values, **new_kwargs)
  1212. @register_jagged_func(
  1213. torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all"
  1214. )
  1215. def is_same_size_default(func, *args, **kwargs):
  1216. return args[0]._size == args[1]._size
def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
    """Shared driver for dim-wise reductions on jagged NestedTensors.

    Args:
        func: the aten overload performing the reduction. It is called either on
            the packed values buffer or on a padded dense version of the input.
        func_name: op name used in error messages (and passed to
            ``_wrap_jagged_dims``).
        identity_element: padding value used when only the ragged dim is reduced
            (the input is densified with ``to_padded_tensor(identity_element)``),
            so it should be a value that does not change the reduction's result.
        *args, **kwargs: the op's arguments. They are normalized to keywords;
            ``input`` and ``dim`` are read, ``keepdim`` is honored if present.

    Returns:
        * A dense result when the ragged dim is reduced away (full reductions,
          (batch, ragged[, non-batch]) reductions, and ragged-only reductions).
        * NJT result(s) when raggedness is preserved (only non-batch,
          non-ragged dims are reduced).
        Ops returning several tensors are handled element-wise.

    Raises:
        RuntimeError: when reducing over the ragged dim of an NJT with holes
            (``_lengths`` is not None); when reducing over the ragged dim plus
            non-batch dims but not the batch dim; or when reducing over the
            batch dim without the ragged dim.

    NOTE(review): the full-reduction path returns before the ``_lengths`` check
    below, so for NJTs with holes it reduces over the entire values buffer —
    confirm that is intended.
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    # some ops use dim=None to indicate a full reduction; some use an empty dim list
    full_reduction = new_kwargs["dim"] is None or (
        isinstance(new_kwargs["dim"], (tuple, list)) and len(new_kwargs["dim"]) == 0
    )
    if full_reduction:
        out = func(inp._values, **new_kwargs)
        if new_kwargs.get("keepdim", False):
            # the values buffer has one fewer dim than the NJT; add it back at the
            # ragged position so the kept-dim result has the outer rank
            if isinstance(out, (tuple, list)):
                # some ops return multiple things; unsqueeze all of them
                out = type(out)(o.unsqueeze(inp._ragged_idx) for o in out)
            else:
                out = out.unsqueeze(inp._ragged_idx)
        return out

    # some ops support lists of dims; some don't
    dim_to_convert = new_kwargs["dim"]
    is_dimlist = isinstance(new_kwargs["dim"], (tuple, list))
    if not is_dimlist:
        dim_to_convert = [dim_to_convert]

    # map outer (NJT) dims to values-space dims and classify which kinds of
    # dims (batch / ragged / other) the reduction touches
    (
        converted_dim,
        reduce_on_batch,
        reduce_on_ragged,
        reduce_on_non_batch,
    ) = _wrap_jagged_dims(
        inp.dim(),
        dim_to_convert,
        f"{func_name}",
        inp._ragged_idx,
    )

    if not is_dimlist:
        # convert back from list
        converted_dim = converted_dim[0]
    new_kwargs["dim"] = converted_dim

    if reduce_on_ragged and inp._lengths is not None:
        raise RuntimeError(
            f"{func_name}(): reducing across the ragged dimension is not supported "
            "for non-contiguous nested tensors with holes"
        )

    from torch.utils._pytree import tree_map

    # raggedness reduced away --> return dense tensor
    if reduce_on_ragged:
        # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
        if reduce_on_batch:
            # no need to read offsets --> apply the op directly on values
            out = func(inp._values, **new_kwargs)
            if new_kwargs.get("keepdim", False):
                # some ops return multiple things; unsqueeze all of them
                out = tree_map(lambda o: o.unsqueeze(0), out)
            return out
        else:
            # invalid reduction cases: (ragged, non-batch), etc.
            if reduce_on_non_batch:
                raise RuntimeError(
                    f"{func_name}(): reducing along a ragged and non-batch dimension "
                    "is not supported for nested tensors"
                )

            # reduction cases: (ragged)
            # convert to padded dense (padding with the identity element) and
            # reduce along the outer ragged position of the padded tensor
            new_kwargs.pop("dim")
            dim_to_pass = [inp._ragged_idx] if is_dimlist else inp._ragged_idx
            return func(
                inp.to_padded_tensor(identity_element), dim=dim_to_pass, **new_kwargs
            )
    # raggedness preserved --> return nested tensor
    else:
        # invalid reduction cases: (batch), (batch, non-batch), etc.
        if reduce_on_batch:
            raise RuntimeError(
                f"{func_name}(): reducing along the batch dimension but not "
                "the ragged dimension is not supported for nested tensors"
            )

        # reduction cases: (non-batch), (non-batch, non-batch), etc.
        # apply the op directly on values
        out = func(inp._values, **new_kwargs)
        out_kwargs = extract_kwargs(inp)
        if not new_kwargs.get("keepdim", False):
            # dims are reduced away -> ragged_idx of output needs to be reevaluated
            dimlist = (
                new_kwargs["dim"]
                if isinstance(new_kwargs["dim"], (tuple, list))
                else [new_kwargs["dim"]]
            )
            for d in dimlist:
                # d is a values-space dim; the packed ragged dim sits at
                # ragged_idx - 1, so every dim removed in front of it shifts the
                # output's ragged index left by one
                if d < inp._ragged_idx - 1:
                    out_kwargs["_ragged_idx"] -= 1

        # some ops return multiple things; wrap each of them as an NJT
        return tree_map(lambda o: NestedTensor(o, **out_kwargs), out)
  1310. @register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?")
  1311. def sum_default(func, *args, **kwargs):
  1312. _, new_kwargs = normalize_function( # type: ignore[misc]
  1313. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1314. )
  1315. inp = new_kwargs.pop("input")
  1316. return func(inp._values, **new_kwargs)
  1317. @register_jagged_func(
  1318. torch.ops.aten.sum.dim_IntList,
  1319. "self: jt_all, dim: any?, keepdim: any?, dtype: any?",
  1320. )
  1321. def sum_dim_IntList(func, *args, **kwargs):
  1322. return _apply_reduction(func, "sum", 0, *args, **kwargs)
@register_jagged_func(
    torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
)
def transpose_int(func, *args, **kwargs):
    """Transpose two dims of a jagged NJT.

    If either dim is the ragged dim, the values buffer is transposed and the
    output's ``_ragged_idx`` moves to the other dim. Otherwise both dims are
    mapped to values space via ``_wrap_jagged_dim`` and ``func`` runs on the
    values with the input's metadata unchanged.

    Raises:
        ValueError: if the ragged dim is swapped with the batch dim (dim 0).
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    from torch._prims_common import canonicalize_dims

    inp = new_kwargs.pop("input")
    # canonicalized (non-negative) outer dims, used only for the ragged checks
    dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))

    # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2
    # instead of 1, although the internal Flash and mem-effn implementations will
    # use the inputs with raggedness in dim 1.
    if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
        if dim0 == 0 or dim1 == 0:
            raise ValueError(
                "Transpose is not supported on the batch dimension for jagged NT"
            )
        # raggedness follows the ragged dim to its new position
        if dim0 == inp._ragged_idx:
            to_dim = dim1
        else:
            to_dim = dim0
        inp_kwargs = extract_kwargs(inp)
        inp_kwargs["_ragged_idx"] = to_dim
        # NOTE(review): this path uses inp.values() while the rest of the block
        # uses inp._values — confirm the difference is intentional.
        return NestedTensor(
            inp.values().transpose(
                _outer_to_inner_dim(len(inp._size), dim0, inp._ragged_idx),
                _outer_to_inner_dim(len(inp._size), dim1, inp._ragged_idx),
            ),
            **inp_kwargs,
        )

    # neither dim is ragged: convert both to values-space dims
    new_kwargs["dim0"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim0"], inp._ragged_idx, "transpose"
    )
    new_kwargs["dim1"] = _wrap_jagged_dim(
        inp.dim(), new_kwargs["dim1"], inp._ragged_idx, "transpose"
    )

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  1361. @register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any")
  1362. def permute_default(func, *args, **kwargs):
  1363. _, new_kwargs = normalize_function( # type: ignore[misc]
  1364. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1365. )
  1366. inp = new_kwargs.pop("input")
  1367. dims = new_kwargs.pop("dims")
  1368. inp_kwargs = extract_kwargs(inp)
  1369. inp_dim = len(inp._size)
  1370. # The first two checks are the same as the checks in the normal permute implementation
  1371. if inp_dim != len(dims):
  1372. raise ValueError(
  1373. f"permute(): number of dimensions in the tensor input ({inp_dim}) "
  1374. + f"does not match the length of the desired ordering of dimensions ({len(dims)}).",
  1375. )
  1376. from torch._prims_common import canonicalize_dims
  1377. canonicalized_dims = canonicalize_dims(inp_dim, dims)
  1378. if len(canonicalized_dims) != len(set(canonicalized_dims)):
  1379. raise ValueError("permute(): duplicate dims are not allowed.")
  1380. if inp._lengths is not None:
  1381. raise ValueError(
  1382. "permute(): not supported on jagged layout nested tensor with holes"
  1383. )
  1384. if canonicalized_dims[0] != 0:
  1385. raise ValueError(
  1386. "Permute is not supported on the batch dimension for jagged NT"
  1387. )
  1388. inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
  1389. inner_dims = [
  1390. _outer_to_inner_dim(inp_dim, dim, inp._ragged_idx)
  1391. for dim in canonicalized_dims[1:]
  1392. ]
  1393. new_kwargs["dims"] = inner_dims
  1394. return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
@register_jagged_func(
    [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
    "self: jt_all, size: any",
)
def view_default(func, *args, **kwargs):
    """View a jagged NJT with a new outer size.

    The target outer ``size`` is translated to a values-space size (dropping the
    batch dim and substituting the packed ragged length), then ``func`` is applied
    to the values buffer under the input's inference-mode setting.

    Raises:
        RuntimeError: if ``_ragged_idx != 1`` and the size would change; or if
            ``size`` has fewer than 3 dims or does not keep the input's
            raggedness.
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    size = new_kwargs.pop("size")

    if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size):
        raise RuntimeError(
            f"view(): does not support ragged_idx != 1 except when inp._size == size. "
            f"inp._size is ({inp._size}) and size is ({size})."
        )

    # Ensure specified size still includes batch and ragged dims
    if len(size) < 3 or not raggedness_matches(inp, size):
        raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")

    # outer size: the size of the NT, e.g. [3, j0, 10]
    # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8])
    # this function gets inner_size[inner_idx] for a given inner_idx.
    #
    # example: for outer size [a, b, c, j0, d, e, f]
    # assume that j0 is ragged, other are concrete integers
    # and ragged_idx=3
    # inner size will be [b, c, inp._values.size(ragged_idx - 1), d, e, f]
    # therefore:
    #    inner_size[0] = outer_size[1]
    #    inner_size[1] = outer_size[2]
    #    inner_size[2] = inp._values.size(ragged_idx - 1)
    #    inner_size[3] = outer_size[4]
    #    inner_size[4] = outer_size[5]
    def get_inner_size(inner_idx):
        nonlocal inp, size
        if inner_idx == inp._ragged_idx - 1:
            return inp._values.size(inner_idx)
        else:
            return size[inner_idx + 1]

    inner_size = [get_inner_size(i) for i in range(len(size) - 1)]

    # Preserve inference-mode-ness of input.
    # TODO: Do this for all other views!
    with torch.inference_mode(inp.is_inference()):
        return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))
@register_jagged_func(
    torch.ops.aten.native_layer_norm.default,
    "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
)
def native_layer_norm_default(func, *args, **kwargs):
    """Layer norm for jagged NJTs; returns ``(output_njt, mean, third_stat)``.

    * If the ragged dim's size appears in ``normalized_shape``, normalization
      spans the ragged dim. The NJT is padded to dense, statistics are computed
      per batch item with a mask over the real (non-padded) elements, and the
      normalized result is converted back into jagged values.
    * Otherwise ``func`` runs directly on the packed values buffer.

    Raises:
        RuntimeError: if ``inp.dim() <= 2``; if ``normalized_shape`` covers
            every dim (i.e. would include the batch dim); or if normalizing over
            the ragged dim of an NJT with holes (``_lengths`` is not None).

    NOTE(review): the ragged-dim branch never reads ``weight``/``bias`` from
    ``new_kwargs``, so affine parameters appear to be ignored there — confirm.
    NOTE(review): that branch returns ``std = sqrt(var + eps)`` as the third
    output, while ``native_layer_norm_backward``'s schema in this file names
    that argument ``rstd`` — confirm the two agree before relying on backward.
    NOTE(review): ``flatten(start_dim=inp._ragged_idx)`` yields a 2D buffer only
    when ``_ragged_idx == 1``; the branch presumably assumes that layout.
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    if inp.dim() <= 2:
        raise RuntimeError(
            "layer_norm(): not supported for NestedTensor objects with 2 or fewer dimensions"
        )

    normalized_shape = new_kwargs["normalized_shape"]
    # the ragged dim's size (compared by membership against normalized_shape)
    ragged_size = inp.shape[inp._ragged_idx]

    num_dims_not_normalized = inp.dim() - len(normalized_shape)

    # error if trying to normalize over the batch dimension
    if (
        num_dims_not_normalized == 0
    ):
        raise RuntimeError(
            "layer_norm(): not supported when normalizing over the batch dimension for NestedTensor"
        )

    if ragged_size in normalized_shape and inp._lengths is not None:
        raise RuntimeError(
            "layer_norm(): not supported where lengths is not None if operating on the ragged dimension for NestedTensor"
        )

    # special handling for normalizing over the ragged dimension
    if (
        ragged_size in normalized_shape
    ):
        # _jagged_to_padded_dense_forward requires values to be a 2D tensor;
        # padded to the max length of the ragged dimension.
        # NOTE(review): no padding value is passed, so padded slots hold the
        # kernel's default; the sum below assumes they contribute 0 — confirm.
        padded_input = torch.ops.aten._jagged_to_padded_dense_forward(
            inp._values.flatten(
                start_dim=inp._ragged_idx
            ),
            [inp._offsets],
            max_lengths=[inp._max_seqlen],
        )

        # mask of ones over real elements (outside the ragged lengths it holds
        # the padding value), expanded to the same shape as the padded input
        padded_mask = torch.ops.aten._jagged_to_padded_dense_forward(
            torch.ones((inp._values.shape[0], 1), device=inp.device, dtype=inp.dtype),
            [inp._offsets],
            max_lengths=[inp._max_seqlen],
        ).expand(
            padded_input.shape
        )

        # per-item element count: ragged length * inner dim, since we sum over
        # dims (1, 2) (the layer on which we normalize)
        ragged_lengths = (
            inp._offsets.diff().unsqueeze(1).unsqueeze(1) * padded_input.shape[2]
        )

        # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be
        # an instance norm
        mean = (
            torch.sum(
                padded_input,
                dim=(1, 2),
                keepdim=True,
            )
            / ragged_lengths
        )

        # mask elements outside of the ragged dimension size for correct
        # variance calculation
        padded_normalized = (
            (padded_input - mean) * padded_mask
        )

        # biased variance over the same (1, 2) layer
        variance = (
            torch.sum(
                torch.square(padded_normalized),
                dim=(1, 2),
                keepdim=True,
            )
            / ragged_lengths
        )

        std = torch.sqrt(variance + new_kwargs["eps"])
        padded_layer_norm = padded_normalized / std

        # back to jagged values; providing total_L helps avoid a GPU/CPU sync.
        # Then unflatten the last dimension back into the original nested
        # tensor shape, e.g. (B, *, WH) --> (B, *, W, H)
        jagged_layer_norm_values = torch.ops.aten._padded_dense_to_jagged_forward(
            padded_layer_norm,
            [inp._offsets],
            total_L=inp._values.shape[
                0
            ],
        ).unflatten(
            -1, inp.shape[inp._ragged_idx + 1 :]
        )

        return (
            NestedTensor(jagged_layer_norm_values, **extract_kwargs(inp)),
            mean,
            std,
        )

    # normalized dims are all inside each row: run the op on the values buffer
    output, mean, std = func(inp._values, **new_kwargs)
    return (NestedTensor(output, **extract_kwargs(inp)), mean, std)
  1521. @register_jagged_func(
  1522. torch.ops.aten.native_layer_norm_backward.default,
  1523. "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
  1524. )
  1525. def native_layer_norm_backward_default(func, *args, **kwargs):
  1526. _, new_kwargs = normalize_function( # type: ignore[misc]
  1527. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1528. )
  1529. grad_out = new_kwargs.pop("grad_out")
  1530. inp = new_kwargs.pop("input")
  1531. d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs)
  1532. if d_input is None:
  1533. return (None, d_gamma, d_beta)
  1534. return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta)
  1535. @register_jagged_func(torch.ops.aten.select.int, "self: jt_all, dim: any, index: any")
  1536. def select_int(func, *args, **kwargs):
  1537. _, new_kwargs = normalize_function( # type: ignore[misc]
  1538. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1539. )
  1540. inp = new_kwargs.pop("input")
  1541. new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
  1542. inp.dim(), new_kwargs["dim"], inp._ragged_idx, "select", allow_batch_dim=True
  1543. )
  1544. # handle batch dim slicing via unbind() for now
  1545. # TODO: make this more efficient
  1546. if operating_on_batch:
  1547. return inp.unbind()[new_kwargs["index"]]
  1548. if inp._lengths is not None:
  1549. raise ValueError(
  1550. "select(): not yet supported on dim != 0 for non-contiguous nested tensor with holes"
  1551. )
  1552. # if selecting before the ragged dim, adjust output ragged_idx
  1553. out_kwargs = extract_kwargs(inp)
  1554. if new_kwargs["dim"] < inp._ragged_idx - 1:
  1555. out_kwargs["_ragged_idx"] -= 1
  1556. return NestedTensor(func(inp._values, **new_kwargs), **out_kwargs)
  1557. @register_jagged_func(
  1558. torch.ops.aten.slice.Tensor,
  1559. "self: jt, dim: any?, start: any?, end: any?, step: any?",
  1560. )
  1561. def slice_tensor(func, *args, **kwargs):
  1562. _, new_kwargs = normalize_function( # type: ignore[misc]
  1563. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1564. )
  1565. inp = new_kwargs.pop("input")
  1566. new_kwargs["dim"] = _wrap_jagged_dim(
  1567. inp.dim(), new_kwargs["dim"], inp._ragged_idx, "slice"
  1568. )
  1569. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  1570. @register_jagged_func(
  1571. torch.ops.aten.index_put.default,
  1572. "input: jt_all, indices: any, values: t, accumulate: any?",
  1573. )
  1574. @register_jagged_func(
  1575. torch.ops.aten.index_put_.default,
  1576. "input: jt_all, indices: any, values: t, accumulate: any?",
  1577. )
  1578. def index_put_(func, *args, **kwargs):
  1579. _, new_kwargs = normalize_function( # type: ignore[misc]
  1580. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1581. )
  1582. inp: NestedTensor = new_kwargs.pop("input")
  1583. # For index_put_ to work, we add together the indices of the ragged dimension
  1584. # and the batch dimension, adding the offsets of each ragged dimension to its
  1585. # indices
  1586. indices = new_kwargs.pop("indices")
  1587. if len(indices) > inp.dim():
  1588. raise AssertionError(
  1589. f"Too many indices: got {len(indices)} but tensor has {inp.dim()} dimensions"
  1590. )
  1591. if len(indices) < inp._ragged_idx + 1:
  1592. if not inp.is_contiguous():
  1593. raise RuntimeError(
  1594. "index_put(): If ragged dimension is not part of indices, this only works on contiguous NJTs"
  1595. )
  1596. # Ragged dim is NOT part of indices, we need to pad the nested tensor to apply func
  1597. from .nested_tensor import nested_from_padded
  1598. min_seqlen = inp._maybe_min_seqlen
  1599. max_seqlen = inp._maybe_max_seqlen
  1600. padded_max_S = max_seqlen
  1601. total_L = inp._values.shape[inp._ragged_idx - 1]
  1602. if padded_max_S is None:
  1603. # use upper bound on max seqlen if it's not present
  1604. padded_max_S = total_L
  1605. padded_shape = (
  1606. *inp.shape[: inp._ragged_idx],
  1607. padded_max_S,
  1608. *inp.shape[inp._ragged_idx + 1 :],
  1609. )
  1610. padded_inp = inp.to_padded_tensor(0.0, output_size=padded_shape)
  1611. new_njt = nested_from_padded(
  1612. func(padded_inp, indices, **new_kwargs),
  1613. offsets=inp._offsets,
  1614. ragged_idx=inp._ragged_idx,
  1615. sum_S=total_L,
  1616. min_seqlen=min_seqlen,
  1617. max_seqlen=max_seqlen,
  1618. )
  1619. if func is torch.ops.aten.index_put_.default:
  1620. inp._values.copy_(new_njt.values())
  1621. return inp
  1622. return new_njt
  1623. # We can run on the underlying values directly
  1624. # Validate indices
  1625. if inp.lengths() is None:
  1626. lengths = inp.offsets().diff()
  1627. else:
  1628. lengths = inp.lengths()
  1629. torch._assert_async(
  1630. # pyrefly: ignore [no-matching-overload]
  1631. torch.all(indices[inp._ragged_idx] < lengths),
  1632. "Some indices in the ragged dimension are out of bounds!",
  1633. )
  1634. # Recompute indices for _values
  1635. ragged_indices = inp.offsets()[indices[0]] + indices[inp._ragged_idx]
  1636. func_indices = (
  1637. # before ragged dim
  1638. indices[1 : inp._ragged_idx]
  1639. # ragged dim (combined with batch)
  1640. + [ragged_indices]
  1641. # after ragged dim
  1642. + indices[inp._ragged_idx + 1 :]
  1643. )
  1644. if func is torch.ops.aten.index_put_.default:
  1645. inp._values = func(inp._values, func_indices, **new_kwargs)
  1646. return inp
  1647. return NestedTensor(
  1648. func(inp._values, func_indices, **new_kwargs),
  1649. **extract_kwargs(inp),
  1650. )
  1651. @register_jagged_func(
  1652. torch.ops.aten.convolution.default,
  1653. "input: jt, weight: t, bias: t?, stride: any, padding: any, "
  1654. "dilation: any, transposed: any, output_padding: any, groups: any",
  1655. )
  1656. def convolution_default(func, *args, **kwargs):
  1657. _, new_kwargs = normalize_function( # type: ignore[misc]
  1658. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1659. )
  1660. inp = new_kwargs.pop("input")
  1661. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
@register_jagged_func(
    torch.ops.aten.mean.dim, "self: jt_all, dim: any?, keepdim: any?, dtype: any?"
)
def mean_dim(func, *args, **kwargs):
    """Dim-wise mean for jagged NJTs.

    When the ragged dim is reduced without the batch dim, the result is a
    padded sum (via ``_apply_reduction`` with ``aten.sum.dim_IntList``) divided
    by each batch item's sequence length. All other cases redispatch ``func``
    through ``_apply_reduction``.

    Raises:
        AssertionError: when reducing ragged and non-batch dims without the
            batch dim.
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # "input" is left in new_kwargs: it is forwarded to _apply_reduction below
    inp = new_kwargs["input"]
    (_, reduce_on_batch, reduce_on_ragged, reduce_on_non_batch) = _wrap_jagged_dims(
        inp.dim(),
        new_kwargs["dim"],
        "mean",
        inp._ragged_idx,
    )

    if reduce_on_ragged and not reduce_on_batch:
        if reduce_on_non_batch:
            raise AssertionError(
                "Cannot reduce on both ragged and non-batch dimensions without also reducing on batch"
            )
        # calculate an intermediate sum and leave the dim in for normalization purposes
        keepdim = new_kwargs["keepdim"]
        new_kwargs["keepdim"] = True
        intermediate_sum = _apply_reduction(
            torch.ops.aten.sum.dim_IntList, "mean", 0, **new_kwargs
        )

        # normalize by sequence lengths
        # NOTE(review): _apply_reduction raises when reducing over the ragged dim
        # of an NJT with _lengths set, so the _lengths branch here appears
        # unreachable.
        lengths = inp._lengths if inp._lengths is not None else inp._offsets.diff()
        # reshape lengths from (B,) to (B, 1, ..., 1) so it broadcasts against
        # the kept-dim sum
        for _ in range(intermediate_sum.dim() - 1):
            lengths = lengths.unsqueeze(-1)
        out = intermediate_sum / lengths
        if not keepdim:
            out = out.squeeze(inp._ragged_idx)
        return out

    # at this point, we're just redispatching on the values buffer
    # since we expect it to be unused, specify a weird intermediate value to
    # hopefully make errors obvious
    intermediate_value = 0.42
    return _apply_reduction(func, "mean", intermediate_value, **new_kwargs)
  1700. @register_jagged_func(torch.ops.aten.mean.default, "self: jt_all, dtype: any?")
  1701. def mean_default(func, *args, **kwargs):
  1702. _, new_kwargs = normalize_function( # type: ignore[misc]
  1703. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1704. )
  1705. inp = new_kwargs.pop("input")
  1706. return func(inp._values, **new_kwargs)
  1707. @register_jagged_func(torch.ops.aten.any.dims, "self: jt_all, dim: any?, keepdim: any?")
  1708. def any_dims(func, *args, **kwargs):
  1709. return _apply_reduction(func, "any", False, *args, **kwargs)
  1710. @register_jagged_func(torch.ops.aten.any.dim, "self: jt_all, dim: any, keepdim: any?")
  1711. def any_dim(func, *args, **kwargs):
  1712. _, new_kwargs = normalize_function( # type: ignore[misc]
  1713. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1714. )
  1715. # wrap dim in list to redispatch to dims overload
  1716. new_kwargs["dim"] = [new_kwargs["dim"]]
  1717. return any_dims(torch.ops.aten.any.dims, **new_kwargs)
  1718. @register_jagged_func(torch.ops.aten.all.dims, "self: jt_all, dim: any?, keepdim: any?")
  1719. def all_dims(func, *args, **kwargs):
  1720. return _apply_reduction(func, "all", True, *args, **kwargs)
  1721. @register_jagged_func(torch.ops.aten.all.dim, "self: jt_all, dim: any, keepdim: any?")
  1722. def all_dim(func, *args, **kwargs):
  1723. _, new_kwargs = normalize_function( # type: ignore[misc]
  1724. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1725. )
  1726. # wrap dim in list to redispatch to dims overload
  1727. new_kwargs["dim"] = [new_kwargs["dim"]]
  1728. return all_dims(torch.ops.aten.all.dims, **new_kwargs)
  1729. @register_jagged_func(
  1730. [
  1731. torch.ops.aten.all.default,
  1732. torch.ops.aten.any.default,
  1733. torch.ops.aten.max.default,
  1734. torch.ops.aten.min.default,
  1735. ],
  1736. "self: jt_all",
  1737. )
  1738. def all_any_max_min_default(func, *args, **kwargs):
  1739. _, new_kwargs = normalize_function( # type: ignore[misc]
  1740. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1741. )
  1742. inp = new_kwargs.pop("input")
  1743. return func(inp._values, **new_kwargs)
  1744. @register_jagged_func(
  1745. [torch.ops.aten._is_all_true.default, torch.ops.aten._is_any_true.default],
  1746. "self: jt_all",
  1747. )
  1748. def _is_true_default(func, *args, **kwargs):
  1749. _, new_kwargs = normalize_function( # type: ignore[misc]
  1750. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1751. )
  1752. inp = new_kwargs.pop("input")
  1753. return func(inp._values)
  1754. @register_jagged_func(torch.ops.aten.min.dim, "self: jt_all, dim: any, keepdim: any?")
  1755. def min_dim(func, *args, **kwargs):
  1756. _, new_kwargs = normalize_function( # type: ignore[misc]
  1757. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1758. )
  1759. dtype = new_kwargs["input"].dtype
  1760. dtype_max = _get_padding_value(dtype, "max")
  1761. return _apply_reduction(func, "min", dtype_max, *args, **kwargs)
  1762. @register_jagged_func(torch.ops.aten.max.dim, "self: jt_all, dim: any, keepdim: any?")
  1763. def max_dim(func, *args, **kwargs):
  1764. _, new_kwargs = normalize_function( # type: ignore[misc]
  1765. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1766. )
  1767. dtype = new_kwargs["input"].dtype
  1768. dtype_min = _get_padding_value(dtype, "min")
  1769. return _apply_reduction(func, "max", dtype_min, *args, **kwargs)
  1770. @register_jagged_func(
  1771. torch.ops.aten.amin.default, "self: jt_all, dim: any?, keepdim: any?"
  1772. )
  1773. def amin_default(func, *args, **kwargs):
  1774. _, new_kwargs = normalize_function( # type: ignore[misc]
  1775. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1776. )
  1777. dtype = new_kwargs["input"].dtype
  1778. dtype_max = _get_padding_value(dtype, "max")
  1779. return _apply_reduction(func, "amin", dtype_max, *args, **kwargs)
  1780. @register_jagged_func(
  1781. torch.ops.aten.amax.default, "self: jt_all, dim: any?, keepdim: any?"
  1782. )
  1783. def amax_default(func, *args, **kwargs):
  1784. _, new_kwargs = normalize_function( # type: ignore[misc]
  1785. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1786. )
  1787. dtype = new_kwargs["input"].dtype
  1788. dtype_min = _get_padding_value(dtype, "min")
  1789. return _apply_reduction(func, "amax", dtype_min, *args, **kwargs)
  1790. @register_jagged_func(
  1791. torch.ops.aten.argmin.default, "self: jt_all, dim: any?, keepdim: any?"
  1792. )
  1793. def argmin_default(func, *args, **kwargs):
  1794. _, new_kwargs = normalize_function( # type: ignore[misc]
  1795. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1796. )
  1797. dtype = new_kwargs["input"].dtype
  1798. dtype_max = _get_padding_value(dtype, "max")
  1799. return _apply_reduction(func, "argmin", dtype_max, *args, **kwargs)
  1800. @register_jagged_func(
  1801. torch.ops.aten.argmax.default, "self: jt_all, dim: any?, keepdim: any?"
  1802. )
  1803. def argmax_default(func, *args, **kwargs):
  1804. _, new_kwargs = normalize_function( # type: ignore[misc]
  1805. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1806. )
  1807. dtype = new_kwargs["input"].dtype
  1808. dtype_min = _get_padding_value(dtype, "min")
  1809. return _apply_reduction(func, "argmax", dtype_min, *args, **kwargs)
  1810. @register_jagged_func(
  1811. torch.ops.aten.value_selecting_reduction_backward.default,
  1812. "grad: jt_all, dim: any, indices: jt_all, sizes: any, keepdim: any",
  1813. )
  1814. def value_selecting_reduction_backward_default(func, *args, **kwargs):
  1815. from torch.fx.experimental.symbolic_shapes import is_nested_int
  1816. _, new_kwargs = normalize_function( # type: ignore[misc]
  1817. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1818. )
  1819. grad = new_kwargs.pop("grad")
  1820. new_kwargs["grad"] = grad._values
  1821. indices = new_kwargs.pop("indices")
  1822. new_kwargs["indices"] = indices._values
  1823. # should always succeed; sizes should contain a nested int
  1824. ragged_idx = next(i for i, s in enumerate(new_kwargs["sizes"]) if is_nested_int(s))
  1825. # convert dim -> values-space dim
  1826. new_kwargs["dim"] = _wrap_jagged_dim(
  1827. len(new_kwargs["sizes"]),
  1828. new_kwargs["dim"],
  1829. ragged_idx,
  1830. "value_selecting_reduction_backward",
  1831. )
  1832. # convert saved NJT sizes -> values-space sizes
  1833. sizes = new_kwargs.pop("sizes")
  1834. sizes[ragged_idx] = indices._values.size(indices._ragged_idx - 1)
  1835. sizes = sizes[1:]
  1836. new_kwargs["sizes"] = sizes
  1837. output_kwargs = extract_kwargs(indices)
  1838. output_kwargs["_ragged_idx"] = ragged_idx
  1839. return NestedTensor(func(**new_kwargs), **output_kwargs)
  1840. @register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any?")
  1841. def stack_default(func, *args, **kwargs):
  1842. _, new_kwargs = normalize_function( # type: ignore[misc]
  1843. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1844. )
  1845. # guaranteed this is non-empty if we got here
  1846. tensors = new_kwargs.pop("tensors")
  1847. for t in tensors:
  1848. if not isinstance(t, NestedTensor):
  1849. raise RuntimeError("stack(): expected all nested tensors inputs")
  1850. if t.dim() != tensors[0].dim():
  1851. raise RuntimeError(
  1852. "stack(): expected all nested tensors to have the same dim"
  1853. )
  1854. if not raggedness_matches(t, tensors[0].shape):
  1855. raise RuntimeError(
  1856. "stack(): expected all nested tensors to have the same nested structure"
  1857. )
  1858. new_kwargs["dim"] = _wrap_jagged_dim(
  1859. tensors[0].dim() + 1, new_kwargs["dim"], tensors[0]._ragged_idx, "stack"
  1860. )
  1861. return NestedTensor(
  1862. func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
  1863. )
  1864. @register_jagged_func(
  1865. torch.ops.aten.embedding.default,
  1866. "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
  1867. )
  1868. def embedding_default(func, *args, **kwargs):
  1869. _, new_kwargs = normalize_function( # type: ignore[misc]
  1870. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1871. )
  1872. # guaranteed this is non-empty if we got here
  1873. indices = new_kwargs.pop("indices")
  1874. weight = new_kwargs.pop("weight")
  1875. return NestedTensor(
  1876. func(weight, indices._values, **new_kwargs), **extract_kwargs(indices)
  1877. )
  1878. @register_jagged_func(
  1879. torch.ops.aten.embedding_dense_backward.default,
  1880. "grad_output: jt, indices: jt, num_weights: any, padding_idx: any, scale_grad_by_freq: any",
  1881. )
  1882. def embedding_dense_backward_default(func, *args, **kwargs):
  1883. _, new_kwargs = normalize_function( # type: ignore[misc]
  1884. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1885. )
  1886. indices = new_kwargs.pop("indices")
  1887. grad_output = new_kwargs.pop("grad_output")
  1888. return func(grad_output._values, indices._values, **new_kwargs)
  1889. @register_jagged_func(
  1890. [
  1891. torch.ops.aten.values.default,
  1892. torch.ops.aten._nested_get_values.default,
  1893. ],
  1894. "self: jt_all",
  1895. )
  1896. def values_default(func, *args, **kwargs):
  1897. _, new_kwargs = normalize_function( # type: ignore[misc]
  1898. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1899. )
  1900. inp = new_kwargs.pop("input")
  1901. # TODO: Handle inference mode properly.
  1902. # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292
  1903. return inp._values.detach()
  1904. @register_jagged_func(torch.ops.aten.all.default, "self: jt_all")
  1905. def all_default(func, *args, **kwargs):
  1906. _, new_kwargs = normalize_function( # type: ignore[misc]
  1907. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1908. )
  1909. inp = new_kwargs.pop("input")
  1910. return func(inp._values)
  1911. @register_jagged_func(
  1912. torch.ops.aten.to_padded_tensor.default,
  1913. "self: jt_all, padding: any, output_size: any?",
  1914. )
  1915. def to_padded_tensor_default(func, *args, **kwargs):
  1916. _, new_kwargs = normalize_function( # type: ignore[misc]
  1917. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1918. )
  1919. inp = new_kwargs.pop("input")
  1920. if inp._lengths is not None:
  1921. raise RuntimeError(
  1922. "to_padded_tensor(): not supported for nested tensors with holes"
  1923. )
  1924. # TODO: Handle the rest of output_size
  1925. output_size = new_kwargs["output_size"]
  1926. if output_size is not None:
  1927. max_seq_len = output_size[inp._ragged_idx]
  1928. else:
  1929. max_seq_len = (
  1930. inp._max_seqlen
  1931. if inp._max_seqlen_tensor is not None
  1932. else inp._values.size(0)
  1933. )
  1934. # only 2D values with ragged packed dim=0 is supported by the underlying FBGEMM
  1935. # kernel so do shape gymnastics if needed
  1936. values = inp.values()
  1937. if inp._ragged_idx > 1:
  1938. values = values.transpose(inp._ragged_idx - 1, 0)
  1939. values_shape = values.shape
  1940. if values.dim() > 2:
  1941. values = values.flatten(start_dim=1)
  1942. elif values.dim() == 1:
  1943. values = values.unsqueeze(-1)
  1944. # NB: The CUDA kernel for jagged -> padded dense conversion does not support
  1945. # integer / bool types; work around this by casting to half.
  1946. is_bool = values.dtype is torch.bool
  1947. if is_bool and values.is_cuda:
  1948. values = values.to(torch.half)
  1949. padded_out = torch.ops.aten._jagged_to_padded_dense_forward(
  1950. values,
  1951. [inp._offsets],
  1952. [max_seq_len],
  1953. new_kwargs["padding"],
  1954. )
  1955. if is_bool and padded_out.is_cuda:
  1956. padded_out = padded_out.to(torch.bool)
  1957. # shape gymnastics part 2
  1958. if len(values_shape) > 2:
  1959. padded_out = padded_out.unflatten(-1, values_shape[1:])
  1960. elif len(values_shape) == 1:
  1961. padded_out = padded_out.squeeze(-1)
  1962. if inp._ragged_idx > 1:
  1963. padded_out = padded_out.transpose(inp._ragged_idx, 1)
  1964. return padded_out
  1965. @register_jagged_func(
  1966. torch.ops.aten._nested_from_padded_tensor.default,
  1967. "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?",
  1968. )
  1969. def _nested_from_padded_tensor_default(func, *args, **kwargs):
  1970. _, new_kwargs = normalize_function( # type: ignore[misc]
  1971. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1972. )
  1973. padded, offsets = new_kwargs["padded"], new_kwargs["offsets"]
  1974. ragged_idx = new_kwargs.get("ragged_idx", 1)
  1975. # only 3D padded with ragged packed dim=0 is supported by the underlying FBGEMM
  1976. # kernel so do shape gymnastics
  1977. if ragged_idx > 1:
  1978. padded = padded.transpose(ragged_idx, 1)
  1979. padded_ragged_dim1_shape = padded.shape
  1980. if padded.dim() > 3:
  1981. padded = padded.flatten(start_dim=2)
  1982. elif padded.dim() < 3:
  1983. padded = padded.unsqueeze(-1)
  1984. # NB: The CUDA kernel for padded dense -> jagged conversion does not support
  1985. # integer / bool types; work around this by casting to half.
  1986. is_bool = padded.dtype is torch.bool
  1987. if is_bool and padded.is_cuda:
  1988. padded = padded.to(torch.half)
  1989. values = torch.ops.aten._padded_dense_to_jagged_forward(
  1990. padded, [offsets], new_kwargs["sum_S"]
  1991. )
  1992. if is_bool and values.is_cuda:
  1993. values = values.to(torch.bool)
  1994. # shape gymnastics part 2
  1995. if len(padded_ragged_dim1_shape) > 3:
  1996. values = values.unflatten(-1, padded_ragged_dim1_shape[2:])
  1997. elif len(padded_ragged_dim1_shape) < 3:
  1998. values = values.squeeze(-1)
  1999. if ragged_idx > 1:
  2000. values = values.transpose(ragged_idx - 1, 0)
  2001. min_seqlen = new_kwargs["min_seqlen"]
  2002. max_seqlen = new_kwargs["max_seqlen"]
  2003. metadata_cache = {}
  2004. if min_seqlen is not None:
  2005. metadata_cache["min_seqlen"] = min_seqlen
  2006. if max_seqlen is not None:
  2007. metadata_cache["max_seqlen"] = max_seqlen
  2008. return NestedTensor(
  2009. values,
  2010. offsets,
  2011. _ragged_idx=ragged_idx,
  2012. _metadata_cache=metadata_cache,
  2013. )
  2014. @register_jagged_func(
  2015. torch.ops.aten._nested_view_from_jagged.default,
  2016. "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
  2017. )
  2018. def _nested_view_from_jagged_default(func, *args, **kwargs):
  2019. _, new_kwargs = normalize_function( # type: ignore[misc]
  2020. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2021. )
  2022. values, offsets, lengths = (
  2023. new_kwargs["input"],
  2024. new_kwargs["offsets"],
  2025. new_kwargs["lengths"],
  2026. )
  2027. ragged_idx = new_kwargs["ragged_idx"]
  2028. min_seqlen = new_kwargs["min_seqlen"]
  2029. max_seqlen = new_kwargs["max_seqlen"]
  2030. metadata_cache = {}
  2031. if min_seqlen is not None:
  2032. metadata_cache["min_seqlen"] = min_seqlen
  2033. if max_seqlen is not None:
  2034. metadata_cache["max_seqlen"] = max_seqlen
  2035. return NestedTensor(
  2036. values,
  2037. offsets,
  2038. lengths=lengths,
  2039. _ragged_idx=ragged_idx,
  2040. _metadata_cache=metadata_cache,
  2041. )
  2042. @register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
  2043. def _nested_get_offsets(func, *args, **kwargs):
  2044. _, new_kwargs = normalize_function( # type: ignore[misc]
  2045. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2046. )
  2047. inp = new_kwargs.pop("input")
  2048. return inp._offsets
  2049. @register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
  2050. def _nested_get_lengths(func, *args, **kwargs):
  2051. _, new_kwargs = normalize_function( # type: ignore[misc]
  2052. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2053. )
  2054. inp = new_kwargs.pop("input")
  2055. return inp._lengths
  2056. @register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
  2057. def _nested_get_ragged_idx(func, *args, **kwargs):
  2058. _, new_kwargs = normalize_function( # type: ignore[misc]
  2059. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2060. )
  2061. inp = new_kwargs.pop("input")
  2062. return inp._ragged_idx
  2063. @register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
  2064. def _nested_get_min_seqlen(func, *args, **kwargs):
  2065. _, new_kwargs = normalize_function( # type: ignore[misc]
  2066. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2067. )
  2068. inp = new_kwargs.pop("input")
  2069. return inp._metadata_cache.get("min_seqlen", None)
  2070. @register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
  2071. def _nested_get_max_seqlen(func, *args, **kwargs):
  2072. _, new_kwargs = normalize_function( # type: ignore[misc]
  2073. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2074. )
  2075. inp = new_kwargs.pop("input")
  2076. return inp._metadata_cache.get("max_seqlen", None)
  2077. # If a section of the Nested Tensor is fully masked out we still retain the section with a length of 0
  2078. @register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any")
  2079. def masked_select_default(func, *args, **kwargs):
  2080. _, new_kwargs = normalize_function( # type: ignore[misc]
  2081. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2082. )
  2083. inp = new_kwargs.pop("input")
  2084. mask = new_kwargs.pop("mask")
  2085. if inp.ndim > 2:
  2086. raise RuntimeError("masked_select only support 2-D selections currently")
  2087. elif inp.shape != mask.shape:
  2088. raise RuntimeError(
  2089. f"Mask with shape {mask.shape} is not compatible with input's shape {inp.shape}"
  2090. )
  2091. res_values = inp._values.masked_select(mask.values())
  2092. mask_cumsum = F.pad(mask.values().cumsum(dim=0), (1, 0)) # type: ignore[arg-type]
  2093. args = extract_kwargs(inp)
  2094. args["offsets"] = mask_cumsum[inp._offsets]
  2095. return NestedTensor(
  2096. values=res_values,
  2097. **args,
  2098. )
  2099. @register_jagged_func(
  2100. torch.ops.aten._nested_select_backward.default,
  2101. "grad_output: t, self: jt_all, dim: any, index: any",
  2102. )
  2103. def _nested_select_backward_default(func, *args, **kwargs):
  2104. _, new_kwargs = normalize_function( # type: ignore[misc]
  2105. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2106. )
  2107. inp = new_kwargs.pop("input")
  2108. grad_output = new_kwargs.pop("grad_output")
  2109. grad_input = torch.zeros_like(inp, dtype=grad_output.dtype)
  2110. grad_input.select(new_kwargs["dim"], new_kwargs["index"]).copy_(grad_output)
  2111. return grad_input
  2112. @register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any")
  2113. def record_stream_default(func, *args, **kwargs) -> None:
  2114. inp = args[0]
  2115. stream = args[1]
  2116. # ensure all components live until stream computation completes
  2117. func(inp._values, stream)
  2118. func(inp._offsets, stream)
  2119. if inp._lengths is not None:
  2120. func(inp._lengths, stream)
  2121. @register_jagged_func(
  2122. [
  2123. torch.ops.aten.new_empty.default,
  2124. torch.ops.aten.new_zeros.default,
  2125. torch.ops.aten.new_ones.default,
  2126. ],
  2127. "self: jt_all, size: any, dtype: any?, layout: any?, device: any?, pin_memory: any?",
  2128. )
  2129. def new_empty_default(func, *args, **kwargs):
  2130. _, new_kwargs = normalize_function( # type: ignore[misc]
  2131. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2132. )
  2133. inp = new_kwargs.pop("input")
  2134. if len(new_kwargs["size"]) == 0:
  2135. return func(inp._values, **new_kwargs)
  2136. raise RuntimeError("new_empty() not supported for NJT with shape != ()")
  2137. @register_jagged_func(
  2138. [
  2139. torch.ops.aten.elu_backward.default,
  2140. torch.ops.aten.hardshrink_backward.default,
  2141. torch.ops.aten.hardsigmoid_backward.default,
  2142. torch.ops.aten.hardtanh_backward.default,
  2143. torch.ops.aten.softplus_backward.default,
  2144. torch.ops.aten.softshrink_backward.default,
  2145. ],
  2146. "self: jt_all, ...",
  2147. )
  2148. def activation_backward(func, *args, **kwargs):
  2149. # first NJT arg is expected to be grad_output
  2150. grad_output = next(arg for arg in args if isinstance(arg, NestedTensor))
  2151. return NestedTensor(
  2152. func(
  2153. *(arg._values if isinstance(arg, NestedTensor) else arg for arg in args),
  2154. **kwargs,
  2155. ),
  2156. **extract_kwargs(grad_output),
  2157. )
  2158. @register_jagged_func(torch.ops.aten.fill.Scalar, "self: jt_all, value: any")
  2159. def fill_Scalar(func, *args, **kwargs):
  2160. _, new_kwargs = normalize_function( # type: ignore[misc]
  2161. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2162. )
  2163. inp = new_kwargs.pop("input")
  2164. return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
  2165. @register_jagged_func(torch.ops.aten.fill_.Scalar, "self: jt_all, value: any")
  2166. def fill__Scalar(func, *args, **kwargs):
  2167. _, new_kwargs = normalize_function( # type: ignore[misc]
  2168. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2169. )
  2170. inp = new_kwargs.pop("input")
  2171. func(inp._values, **new_kwargs)
  2172. return inp
  2173. @register_jagged_func(torch.ops.aten.frexp.Tensor, "self: jt_all")
  2174. def frexp_Tensor(func, *args, **kwargs):
  2175. _, new_kwargs = normalize_function( # type: ignore[misc]
  2176. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2177. )
  2178. inp = new_kwargs.pop("input")
  2179. output_kwargs = extract_kwargs(inp)
  2180. mantissa, exponent = func(inp._values)
  2181. return NestedTensor(mantissa, **output_kwargs), NestedTensor(
  2182. exponent, **output_kwargs
  2183. )
  2184. @register_jagged_func(
  2185. torch.ops.aten.matmul_backward.default,
  2186. "grad: any, self: any, other: any, mask: any",
  2187. )
  2188. def matmul_backward_default(func, *args, **kwargs):
  2189. _, new_kwargs = normalize_function( # type: ignore[misc]
  2190. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  2191. )
  2192. grad = new_kwargs.pop("grad")
  2193. inp = new_kwargs.pop("input")
  2194. other = new_kwargs.pop("other")
  2195. grad_input_mask = new_kwargs.pop("mask")
  2196. if grad is None:
  2197. return (None, None)
  2198. grad_self = None
  2199. if grad_input_mask[0]:
  2200. grad_self = torch.matmul(grad, other.transpose(-1, -2))
  2201. grad_other = None
  2202. if grad_input_mask[1]:
  2203. grad_other = torch.matmul(inp.transpose(-1, -2), grad)
  2204. return (grad_self, grad_other)
  2205. # Make the dummy available on the C++ side.
  2206. @register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
  2207. def _nested_get_jagged_dummy(func, *args, **kwargs):
  2208. from torch.nested._internal.nested_tensor import _nt_view_dummy
  2209. return _nt_view_dummy()
  2210. with torch.library._scoped_library("aten", "IMPL") as aten:
  2211. aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
  2212. aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
  2213. aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")