_dynamism.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
import copy
import re
from collections.abc import Callable
from typing import Any, Union

import torch
from torch.utils._pytree import tree_flatten_with_path, tree_map
  6. KeyPath = tuple[Any, ...]
  7. NonTensorShapeFn = Callable[[Union[int, float]], tuple[Any, ...]]
  8. __all__ = [
  9. "normalize_source_name",
  10. "module_to_nested_dict",
  11. "track_dynamism_across_examples",
  12. "clone_and_convert_to_meta",
  13. ]
  14. def normalize_source_name(name: str) -> str:
  15. # Match attribute access like .x and replace with ['x']
  16. return re.sub(r"\.([a-zA-Z_][a-zA-Z0-9_]*)", r"['\1']", name)
  17. def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]:
  18. """Recursively converts an nn.Module into a nested dictionary with explicit 'parameters' and 'modules' keys."""
  19. self_dict: dict[str, Any] = {}
  20. self_dict["_parameters"] = {}
  21. self_dict["_modules"] = {}
  22. for attr_name in dir(module):
  23. try:
  24. if not attr_name.startswith("_") and not callable(
  25. getattr(module, attr_name)
  26. ):
  27. attr_value = getattr(module, attr_name)
  28. if (
  29. not isinstance(attr_value, torch.nn.Module)
  30. and isinstance(attr_value, (int, float, torch.Tensor))
  31. and type(attr_value) is not bool
  32. ):
  33. self_dict[attr_name] = attr_value
  34. except NotImplementedError:
  35. # Skip attributes that raise NotImplementedError since they won't
  36. # contain any dynamism anyways.
  37. continue
  38. for name, param in module.named_parameters(recurse=False):
  39. self_dict["_parameters"][name] = param
  40. for name, buffer in module.named_buffers(recurse=False):
  41. self_dict["_parameters"][name] = buffer
  42. for name, submodule in module.named_children():
  43. self_dict["_modules"][name] = module_to_nested_dict(submodule)
  44. return self_dict
  45. def track_dynamism_across_examples(
  46. example_inputs: list[Any],
  47. ) -> dict[Any, Any]:
  48. """
  49. This function analyzes a list of example inputs to determine the dynamism of their shapes.
  50. It tracks whether the dimensions of tensors or non-tensor values change across
  51. different examples. The function returns a dictionary where each key represents
  52. a path to a value in the input examples, and the corresponding value is a tuple
  53. indicating which dimensions are dynamic (i.e., change across examples). This
  54. helps in understanding how the structure of data varies across different instances.
  55. """
  56. tracking: dict[KeyPath, tuple[list[set[Any]], bool]] = {}
  57. for ex in example_inputs:
  58. if "self" in ex and isinstance(ex["self"], torch.nn.Module):
  59. ex["self"] = module_to_nested_dict(ex["self"])
  60. leaves_with_paths, _ = tree_flatten_with_path(ex)
  61. for key_path, value in leaves_with_paths:
  62. if not isinstance(value, (int, float, torch.Tensor)):
  63. continue
  64. if isinstance(value, torch.Tensor):
  65. shape: tuple[int | float, ...] = tuple(value.shape)
  66. is_tensor = True
  67. else:
  68. shape = (value,)
  69. is_tensor = False
  70. if key_path not in tracking:
  71. tracking[key_path] = ([set() for _ in range(len(shape))], is_tensor)
  72. else:
  73. dim_sets, flag = tracking[key_path]
  74. if flag != is_tensor:
  75. pass
  76. while len(dim_sets) < len(shape):
  77. dim_sets.append(set())
  78. for i, dim in enumerate(shape):
  79. tracking[key_path][0][i].add(dim)
  80. output: dict[Any, Any] = {}
  81. for key_path, (dim_sets, _is_tensor) in tracking.items():
  82. final_dyn = tuple(len(s) > 1 for s in dim_sets)
  83. key_str = "L" + "".join(f"{str(k)}" for k in key_path)
  84. key = key_path[0].key # type: ignore[attr-defined]
  85. if key not in output:
  86. output[key] = {}
  87. output[key][key_str] = final_dyn
  88. return output
  89. def clone_and_convert_to_meta(example_input: Any) -> Any:
  90. """
  91. This function takes a list of example inputs and for each tensor, clones it and converts it to device=meta.
  92. For non-tensor values, it keeps the reference. It uses pytree to handle nested structures recursively.
  93. """
  94. def transform_fn(value: Any) -> Any:
  95. if isinstance(value, torch.Tensor):
  96. return value.clone().to(device="meta")
  97. return value
  98. return tree_map(transform_fn, example_input)