_debug.py 608 B

12345678910111213141516171819202122
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. import torch
  3. from torch.fx.node import Argument
  def friendly_debug_info(v: object) -> Argument:
      """
      Render `v` as a compact, human-readable string for debug output.

      A `torch.Tensor` is summarized by its shape, its `requires_grad` flag
      and its dtype; every other value is rendered with `str(v)`.
      """
      # Guard clause: anything that is not a tensor needs no special handling.
      if not isinstance(v, torch.Tensor):
          return str(v)

      return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})"
  def map_debug_info(a: Argument) -> Argument:
      """
      Apply `friendly_debug_info` to the items in `a` and return the result.

      The traversal is delegated to `torch.fx.node.map_aggregate`.
      `a` may be a list, tuple, or dict.
      """
      aggregate = torch.fx.node.map_aggregate
      return aggregate(a, friendly_debug_info)