# fx.py (1.3 KB)
  1. from collections.abc import Callable
  2. from typing import Any
  3. from torch._C import _fx_map_aggregate, _fx_map_arg
  4. from torch.fx.immutable_collections import immutable_dict, immutable_list
  5. from torch.fx.node import Node
  6. from ..decorators import substitute_in_graph
  7. @substitute_in_graph(_fx_map_arg, can_constant_fold_through=True)
  8. def map_arg(a: Any, fn: Callable[[Node], Any]) -> Any:
  9. return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  10. @substitute_in_graph(_fx_map_aggregate, can_constant_fold_through=True)
  11. def map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any:
  12. result: Any
  13. if isinstance(a, tuple):
  14. it = (map_aggregate(elem, fn) for elem in a)
  15. # Support NamedTuple (if it has `_fields`) by repacking into original type.
  16. result = type(a)(*it) if hasattr(a, "_fields") else tuple(it)
  17. elif isinstance(a, list):
  18. result = immutable_list([map_aggregate(elem, fn) for elem in a])
  19. elif isinstance(a, dict):
  20. result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()])
  21. elif isinstance(a, slice):
  22. result = slice(
  23. map_aggregate(a.start, fn),
  24. map_aggregate(a.stop, fn),
  25. map_aggregate(a.step, fn),
  26. )
  27. else:
  28. result = fn(a)
  29. return result
  30. __all__ = [
  31. "map_arg",
  32. "map_aggregate",
  33. ]