iter.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. """
  2. This module provides iterator-related variable tracking functionality for Dynamo.
  3. It implements variable classes for handling Python iterators and itertools functions
  4. during symbolic execution and tracing.
  5. The module includes:
  6. - Base iterator variable classes for tracking iterator state
  7. - Implementations of built-in iterators (zip, map, filter)
  8. - Support for itertools functions (product, accumulate, combinations, etc.)
  9. - Mutation tracking and reconstruction capabilities for iterator operations
  10. These classes integrate with Dynamo's variable tracking system to enable proper
  11. handling of iterator operations during code transformation and optimization.
  12. """
  13. import itertools
  14. import sys
  15. from collections.abc import Callable, Sequence
  16. from typing import Any, TYPE_CHECKING, Union
  17. from .. import graph_break_hints, polyfills, variables
  18. from ..bytecode_transformation import (
  19. create_build_tuple,
  20. create_call_function,
  21. create_call_function_ex,
  22. create_instruction,
  23. )
  24. from ..exc import (
  25. handle_observed_exception,
  26. ObservedUserStopIteration,
  27. raise_observed_exception,
  28. unimplemented,
  29. UserError,
  30. )
  31. from .base import ValueMutationNew, VariableTracker
  32. from .constant import ConstantVariable
  33. if TYPE_CHECKING:
  34. from torch._dynamo.codegen import PyCodegen
  35. from torch._dynamo.symbolic_convert import InstructionTranslator
  36. MAX_ITERATOR_LIMIT = 100 * 1024 # 100k
  37. class ItertoolsVariable(VariableTracker):
  38. def __init__(self, value: Any, **kwargs: Any) -> None:
  39. super().__init__(**kwargs)
  40. self.value = value
  41. def __repr__(self) -> str:
  42. return f"ItertoolsVariable({self.value})"
  43. def as_python_constant(self) -> Any:
  44. return self.value
  45. def call_function(
  46. self,
  47. tx: "InstructionTranslator",
  48. args: Sequence["VariableTracker"],
  49. kwargs: "dict[str, VariableTracker]",
  50. ) -> "VariableTracker":
  51. # See also: module `torch._dynamo.polyfills.itertools`
  52. if self.value is itertools.product:
  53. if any(kw != "repeat" for kw in kwargs):
  54. unimplemented(
  55. gb_type="Unsupported kwargs for itertools.product",
  56. context=f"call_function {self} {args} {kwargs}",
  57. explanation=f"Expected kwargs: 'repeat', but got "
  58. f"{','.join(set(kwargs.keys()) - {'repeat'})}",
  59. hints=[*graph_break_hints.USER_ERROR],
  60. )
  61. if "repeat" in kwargs:
  62. r = kwargs["repeat"].as_python_constant()
  63. else:
  64. r = 1
  65. seqs = [arg.force_unpack_var_sequence(tx) for arg in args]
  66. items = [
  67. variables.TupleVariable(list(item))
  68. for item in itertools.product(*seqs, repeat=r)
  69. ]
  70. return variables.ListIteratorVariable(
  71. items, # type: ignore[arg-type]
  72. mutation_type=ValueMutationNew(),
  73. )
  74. elif (
  75. self.value is itertools.combinations
  76. and not kwargs
  77. and len(args) == 2
  78. and args[0].has_unpack_var_sequence(tx)
  79. and args[1].is_python_constant()
  80. ):
  81. iterable = args[0].unpack_var_sequence(tx)
  82. r = args[1].as_python_constant()
  83. items = []
  84. for item in itertools.combinations(iterable, r):
  85. items.append(variables.TupleVariable(list(item)))
  86. return variables.ListIteratorVariable(
  87. items, # type: ignore[arg-type]
  88. mutation_type=ValueMutationNew(),
  89. )
  90. elif self.value is itertools.groupby:
  91. if any(kw != "key" for kw in kwargs):
  92. unimplemented(
  93. gb_type="Unsupported kwargs for itertools.groupby",
  94. context=f"call_function {self} {args} {kwargs}",
  95. explanation=f"Expected kwargs: 'key', but got "
  96. f"{','.join(set(kwargs.keys()) - {'key'})}",
  97. hints=[*graph_break_hints.USER_ERROR],
  98. )
  99. def retrieve_const_key(key: VariableTracker) -> Any:
  100. if isinstance(key, variables.SymNodeVariable):
  101. return key.evaluate_expr()
  102. elif key.is_python_constant():
  103. return key.as_python_constant()
  104. else:
  105. unimplemented(
  106. gb_type="Unsupported key type for itertools.groupby",
  107. context=f"call_function {self} {args} {kwargs}",
  108. explanation="Dynamo does not know how to trace "
  109. f"itertools.groupby with key type: {str(type(key))}. "
  110. "We only support grouping keys that are constants (int, float, str, etc.)",
  111. hints=[*graph_break_hints.SUPPORTABLE],
  112. )
  113. if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
  114. seq = args[0].unpack_var_sequence(tx)
  115. else:
  116. unimplemented(
  117. gb_type="Unsupported arguments for itertools.groupby",
  118. context=f"call_function {self} {args} {kwargs}",
  119. explanation="Dynamo does not know how to trace "
  120. f"itertools.groupby with args: {args} and kwargs: {kwargs}. "
  121. "itertools.groupby expects an iterable to group and an "
  122. "optional key function to determine groupings.",
  123. hints=[
  124. "Make sure the arguments to itertools.groupby are correct.",
  125. *graph_break_hints.SUPPORTABLE,
  126. ],
  127. )
  128. if "key" in kwargs:
  129. def keyfunc(x: VariableTracker) -> Any:
  130. return retrieve_const_key(
  131. kwargs.get("key").call_function(tx, [x], {}) # type: ignore[union-attr]
  132. )
  133. else:
  134. def keyfunc(x: VariableTracker) -> Any:
  135. return retrieve_const_key(x)
  136. result = []
  137. try:
  138. # pyrefly: ignore [unbound-name]
  139. for k, v in itertools.groupby(seq, key=keyfunc):
  140. result.append(
  141. variables.TupleVariable(
  142. [
  143. (
  144. variables.ConstantVariable.create(k)
  145. if variables.ConstantVariable.is_literal(k)
  146. else k
  147. ),
  148. variables.ListIteratorVariable(
  149. list(v), mutation_type=ValueMutationNew()
  150. ),
  151. ],
  152. mutation_type=ValueMutationNew(),
  153. )
  154. )
  155. except Exception as e:
  156. unimplemented(
  157. gb_type="Unexpected failure during itertools.groupby() iteration",
  158. context=f"call_function {self} {args} {kwargs}",
  159. explanation="Unexpected failure in invoking function during groupby",
  160. hints=[*graph_break_hints.SUPPORTABLE],
  161. from_exc=e,
  162. )
  163. return variables.ListIteratorVariable(
  164. result, # type: ignore[arg-type]
  165. mutation_type=ValueMutationNew(),
  166. )
  167. elif self.value is itertools.repeat:
  168. if len(args) < 2:
  169. return variables.RepeatIteratorVariable(
  170. *args, mutation_type=ValueMutationNew()
  171. )
  172. return tx.inline_user_function_return(
  173. VariableTracker.build(tx, polyfills.repeat), args, kwargs
  174. )
  175. elif self.value is itertools.count:
  176. return variables.CountIteratorVariable(
  177. *args, mutation_type=ValueMutationNew()
  178. )
  179. elif (
  180. self.value is itertools.permutations
  181. and (len(args) == 1 or (len(args) == 2 and args[1].is_python_constant()))
  182. and not kwargs
  183. ):
  184. if len(args) == 2:
  185. r = args[1].as_python_constant()
  186. else:
  187. r = None
  188. items = [
  189. variables.TupleVariable(list(item))
  190. for item in itertools.permutations(
  191. args[0].force_unpack_var_sequence(tx), r
  192. )
  193. ]
  194. return variables.ListIteratorVariable(
  195. items, # type: ignore[arg-type]
  196. mutation_type=ValueMutationNew(),
  197. )
  198. else:
  199. return super().call_function(tx, args, kwargs)
  200. class IteratorVariable(VariableTracker):
  201. def __init__(self, **kwargs: Any) -> None:
  202. super().__init__(**kwargs)
  203. def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
  204. unimplemented(
  205. gb_type="Unimplemented next() call",
  206. context=f"next({self})",
  207. explanation="This abstract method must be implemented",
  208. hints=[*graph_break_hints.DYNAMO_BUG],
  209. )
  210. # NOTE: only call when unpacking this iterator safely done eagerly!
  211. # Normally, iterators are accessed lazily.
  212. # Example of safe eager unpacking: list(map(f, seq))
  213. # Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
  214. def force_unpack_var_sequence(
  215. self, tx: "InstructionTranslator"
  216. ) -> list[VariableTracker]:
  217. result: list[VariableTracker] = []
  218. self.force_apply_to_var_sequence(tx, result.append)
  219. return result
  220. def force_apply_to_var_sequence(
  221. self, tx: "InstructionTranslator", fn: Callable[[Any], Any]
  222. ) -> None:
  223. while True:
  224. try:
  225. fn(self.next_variable(tx))
  226. except ObservedUserStopIteration:
  227. handle_observed_exception(tx)
  228. break
  229. # don't call force_unpack_var_sequence since it can mutate
  230. # IteratorVariable state!
  231. def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
  232. return True
  233. def call_obj_hasattr(
  234. self, tx: "InstructionTranslator", name: str
  235. ) -> "ConstantVariable":
  236. if name == "__iter__" or name == "__next__":
  237. return variables.ConstantVariable.create(True)
  238. return super().call_obj_hasattr(tx, name)
  239. def call_method(
  240. self,
  241. tx: "InstructionTranslator",
  242. name: str,
  243. args: list[VariableTracker],
  244. kwargs: dict[str, VariableTracker],
  245. ) -> VariableTracker:
  246. if name == "__iter__":
  247. return self
  248. elif name == "__next__":
  249. return self.next_variable(tx)
  250. return super().call_method(tx, name, args, kwargs)
  251. class ObjectIteratorVariable(IteratorVariable):
  252. """
  253. VariableTracker for iter(obj) that implements the iterator protocol (i.e.,
  254. has a `__next__` method).
  255. We use this class to track the state of the iterator and handle the case
  256. when the iterator is exhausted:
  257. Example usage:
  258. > b = iter(obj)
  259. > list(b) # exhaust the iterator
  260. > list(b) # empty list
  261. """
  262. def __init__(self, obj: VariableTracker, **kwargs: Any) -> None:
  263. super().__init__(**kwargs)
  264. self.obj = obj
  265. self.generator_exhausted = False
  266. def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
  267. if self.generator_exhausted:
  268. raise_observed_exception(StopIteration, tx)
  269. try:
  270. return self.obj.next_variable(tx)
  271. except ObservedUserStopIteration:
  272. # Do not rely on the object to always return StopIteration once it
  273. # is exhausted.
  274. self.generator_exhausted = True
  275. raise
  276. class RepeatIteratorVariable(IteratorVariable):
  277. def __init__(self, item: VariableTracker, **kwargs: Any) -> None:
  278. super().__init__(**kwargs)
  279. self.item = item
  280. # Repeat needs no mutation, clone self
  281. def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
  282. return self.item
  283. def reconstruct(self, codegen: "PyCodegen") -> None:
  284. codegen.add_push_null(
  285. lambda: codegen.extend_output(
  286. [
  287. codegen.create_load_python_module(itertools),
  288. codegen.create_load_attr("repeat"),
  289. ]
  290. )
  291. )
  292. codegen(self.item)
  293. codegen.extend_output(create_call_function(1, False))
  294. class CountIteratorVariable(IteratorVariable):
  295. def __init__(
  296. self,
  297. item: Union[int, VariableTracker] = 0,
  298. step: Union[int, VariableTracker] = 1,
  299. **kwargs: Any,
  300. ) -> None:
  301. super().__init__(**kwargs)
  302. if not isinstance(item, VariableTracker):
  303. item = ConstantVariable.create(item)
  304. if not isinstance(step, VariableTracker):
  305. step = ConstantVariable.create(step)
  306. self.item = item
  307. self.step = step
  308. def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
  309. assert self.is_mutable()
  310. old_item = self.item
  311. tx.output.side_effects.mutation(self)
  312. self.item = self.item.call_method(tx, "__add__", [self.step], {})
  313. return old_item
  314. def reconstruct(self, codegen: "PyCodegen") -> None:
  315. codegen.add_push_null(
  316. lambda: codegen.extend_output(
  317. [
  318. codegen.create_load_python_module(itertools),
  319. codegen.create_load_attr("count"),
  320. ]
  321. )
  322. )
  323. codegen(self.item)
  324. codegen(self.step)
  325. codegen.extend_output(create_call_function(2, False))
  326. class ZipVariable(IteratorVariable):
  327. """
  328. Represents zip(*iterables)
  329. """
  330. _nonvar_fields = {
  331. "index",
  332. "strict",
  333. *IteratorVariable._nonvar_fields,
  334. }
  335. def __init__(
  336. self,
  337. iterables: list[VariableTracker],
  338. strict: bool = False,
  339. **kwargs: Any,
  340. ) -> None:
  341. super().__init__(**kwargs)
  342. assert isinstance(iterables, list)
  343. # can be list[Variable] or VariableTracker (with next_variable implemented)
  344. self.iterables = iterables
  345. self.index = 0
  346. self.strict = strict
  347. def python_type(self) -> type[zip]: # type: ignore[type-arg]
  348. return zip
  349. def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
  350. return all(
  351. isinstance(it, list) or it.has_unpack_var_sequence(tx)
  352. for it in self.iterables
  353. )
  354. def unpack_var_sequence(
  355. self, tx: "InstructionTranslator"
  356. ) -> list["VariableTracker"]:
  357. assert self.has_unpack_var_sequence(tx)
  358. iterables = []
  359. for it in self.iterables:
  360. if isinstance(it, list):
  361. iterables.append(it[self.index :])
  362. else:
  363. iterables.append(it.unpack_var_sequence(tx))
  364. kwargs = {"strict": self.strict} if self.strict else {}
  365. zipped = zip(*iterables, **kwargs)
  366. return [variables.TupleVariable(list(var)) for var in zipped]
  367. def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
  368. assert self.is_mutable()
  369. if len(self.iterables) == 0:
  370. raise_observed_exception(StopIteration, tx)
  371. old_index = self.index
  372. args = []
  373. def get_item(
  374. it: Union[list[VariableTracker], VariableTracker],
  375. ) -> VariableTracker:
  376. if isinstance(it, list):
  377. if old_index >= len(it):
  378. raise_observed_exception(StopIteration, tx)
  379. return it[old_index]
  380. else:
  381. return it.next_variable(tx)
  382. idx: int | None = None
  383. try:
  384. for idx, it in enumerate(self.iterables): # noqa:B007
  385. args.append(get_item(it))
  386. except ObservedUserStopIteration:
  387. if self.strict:
  388. if idx == 0:
  389. # all other iterables should be exhausted
  390. for it in self.iterables:
  391. try:
  392. get_item(it)
  393. except ObservedUserStopIteration:
  394. handle_observed_exception(tx)
  395. continue
  396. # no ObservedUserStopIteration - fall through to UserError
  397. break
  398. else:
  399. # all iterables exhausted, raise original error
  400. raise
  401. handle_observed_exception(tx)
  402. raise UserError(
  403. ValueError, # type: ignore[arg-type]
  404. "zip() has one argument of len differing from others",
  405. ) from None
  406. raise
  407. tx.output.side_effects.mutation(self)
  408. self.index += 1
  409. return variables.TupleVariable(args)
  410. def reconstruct_items(self, codegen: "PyCodegen") -> None:
  411. for it in self.iterables:
  412. if isinstance(it, list):
  413. remaining_items = it[self.index :]
  414. codegen.foreach(remaining_items)
  415. codegen.append_output(create_build_tuple(len(remaining_items)))
  416. else:
  417. codegen(it)
  418. def reconstruct(self, codegen: "PyCodegen") -> None:
  419. codegen.add_push_null(
  420. lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
  421. )
  422. self.reconstruct_items(codegen)
  423. codegen.append_output(create_build_tuple(len(self.iterables)))
  424. codegen.extend_output(
  425. [
  426. codegen.create_load_const("strict"),
  427. codegen.create_load_const(self.strict),
  428. create_instruction("BUILD_MAP", arg=1),
  429. *create_call_function_ex(True, False),
  430. ]
  431. )
  432. class MapVariable(ZipVariable):
  433. """
  434. Represents map(fn, *iterables)
  435. """
  436. def __init__(
  437. self,
  438. fn: VariableTracker,
  439. iterables: list[VariableTracker],
  440. **kwargs: Any,
  441. ) -> None:
  442. super().__init__(iterables, **kwargs)
  443. self.fn = fn
  444. def python_type(self) -> type:
  445. return map
  446. def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
  447. return False
  448. def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
  449. args = super().next_variable(tx)
  450. return self.fn.call_function(tx, args.items, {}) # type: ignore[attr-defined]
  451. def reconstruct(self, codegen: "PyCodegen") -> None:
  452. codegen.add_push_null(
  453. lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
  454. )
  455. codegen(self.fn)
  456. self.reconstruct_items(codegen)
  457. codegen.append_output(create_build_tuple(len(self.iterables) + 1))
  458. if self.strict:
  459. assert sys.version_info >= (3, 14), (
  460. "Unexpected bug: map(strict=True) requires Python 3.14+"
  461. )
  462. codegen.extend_output(
  463. [
  464. codegen.create_load_const("strict"),
  465. codegen.create_load_const(self.strict),
  466. create_instruction("BUILD_MAP", arg=1),
  467. *create_call_function_ex(True, False),
  468. ]
  469. )
  470. else:
  471. codegen.extend_output(create_call_function_ex(False, False))
  472. class FilterVariable(IteratorVariable):
  473. """
  474. Represents filter(fn, iterable)
  475. """
  476. _nonvar_fields = {
  477. "index",
  478. *IteratorVariable._nonvar_fields,
  479. }
  480. def __init__(
  481. self,
  482. fn: VariableTracker,
  483. iterable: list[VariableTracker],
  484. **kwargs: Any,
  485. ) -> None:
  486. super().__init__(**kwargs)
  487. self.fn = fn
  488. self.iterable = iterable
  489. self.index = 0
  490. def python_type(self) -> type:
  491. return filter
  492. def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
  493. return isinstance(self.iterable, list) or self.iterable.has_unpack_var_sequence(
  494. tx
  495. )
  496. def unpack_var_sequence(
  497. self, tx: "InstructionTranslator"
  498. ) -> list["VariableTracker"]:
  499. assert self.has_unpack_var_sequence(tx)
  500. it = None
  501. if isinstance(self.iterable, list):
  502. it = self.iterable[self.index :]
  503. else:
  504. it = self.iterable.unpack_var_sequence(tx)
  505. filtered = self.fn.call_function(tx, it, {})
  506. return [variables.TupleVariable([filtered])]
  507. def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
  508. def _next() -> VariableTracker:
  509. old_index = self.index
  510. if isinstance(self.iterable, list):
  511. if old_index >= len(self.iterable):
  512. raise_observed_exception(StopIteration, tx)
  513. return self.iterable[old_index]
  514. else:
  515. return self.iterable.next_variable(tx)
  516. # A do-while loop to find elements that make fn return true
  517. while True:
  518. item = _next()
  519. self.index += 1
  520. if self.fn.is_constant_none():
  521. res = item
  522. else:
  523. res = self.fn.call_function(tx, [item], {})
  524. pred_res = variables.UserFunctionVariable(
  525. polyfills.predicate # type: ignore[arg-type]
  526. ).call_function(tx, [res], {})
  527. if pred_res.as_python_constant():
  528. return item
  529. def reconstruct_items(self, codegen: "PyCodegen") -> None:
  530. if isinstance(self.iterable, list):
  531. remaining_items = self.iterable[self.index :]
  532. codegen.foreach(remaining_items)
  533. codegen.append_output(create_build_tuple(len(remaining_items)))
  534. else:
  535. codegen(self.iterable)
  536. def reconstruct(self, codegen: "PyCodegen") -> None:
  537. codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter"))
  538. codegen(self.fn)
  539. self.reconstruct_items(codegen)
  540. codegen.extend_output(create_call_function(2, False))