param_fetch.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from collections.abc import Callable
  2. from typing import Any
  3. import torch
  4. import torch.nn as nn
  5. from torch.fx._compatibility import compatibility
  6. from torch.fx.graph_module import GraphModule
  7. __all__ = [
  8. "default_matching",
  9. "extract_attrs_for_lowering",
  10. "lift_lowering_attrs_to_nodes",
  11. ]
  12. # Matching method matches the attribute name of current version to the attribute name of `target_version`
  13. @compatibility(is_backward_compatible=False)
  14. def default_matching(name: str, target_version: int) -> str:
  15. """Default matching method"""
  16. return name
  17. # This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
  18. # The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
  19. # If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
  20. module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = {
  21. torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
  22. torch.nn.modules.conv.Conv2d: (
  23. 1,
  24. [
  25. "weight",
  26. "bias",
  27. "kernel_size",
  28. "stride",
  29. "padding",
  30. "dilation",
  31. "groups",
  32. "padding_mode",
  33. ],
  34. default_matching,
  35. ),
  36. torch.nn.modules.batchnorm.BatchNorm2d: (
  37. 2,
  38. ["weight", "bias", "running_mean", "running_var", "eps"],
  39. default_matching,
  40. ),
  41. torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
  42. torch.nn.modules.pooling.MaxPool2d: (
  43. 1,
  44. ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"],
  45. default_matching,
  46. ),
  47. torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
  48. }
  49. @compatibility(is_backward_compatible=False)
  50. def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]:
  51. """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
  52. after checking module's version is compatible with the `module_fetch_book`.
  53. """
  54. attrs_for_lowering: dict[str, Any] = {}
  55. attrs_for_lowering["name"] = torch.typename(mod)
  56. if type(mod) in module_fetch_book:
  57. version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
  58. if version < mod._version:
  59. raise RuntimeError(
  60. f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
  61. "please upgrade the module_fetch_book, open an issue and @842974287 "
  62. "or report a bug to AIACC team directly."
  63. )
  64. for attr in param_to_fetch:
  65. attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
  66. else:
  67. raise RuntimeError(
  68. f"{torch.typename(mod)} is not in the module_fetch_book yet, "
  69. "please add it to the module_fetch_book, open an issue and @842974287 "
  70. "or report a bug to AIACC team directly."
  71. )
  72. return attrs_for_lowering
  73. @compatibility(is_backward_compatible=False)
  74. def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
  75. """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module."""
  76. submodules = dict(fx_module.named_modules())
  77. for node in fx_module.graph.nodes:
  78. if node.op == "call_module":
  79. if isinstance(submodules[node.target], GraphModule):
  80. lift_lowering_attrs_to_nodes(submodules[node.target])
  81. else:
  82. node.attrs_for_lowering = extract_attrs_for_lowering(
  83. submodules[node.target]
  84. )