# unify_refinements.py
  1. # mypy: allow-untyped-defs
  2. from torch.fx.experimental.graph_gradual_typechecker import Refine
  3. from torch.fx.experimental.unification import unify, Var # type: ignore[attr-defined]
  4. from torch.fx.tensor_type import TensorType
  5. def infer_symbolic_types_single_pass(traced):
  6. """
  7. Calls our symbolic inferencer once.
  8. """
  9. r = Refine(traced)
  10. r.refine()
  11. mgu = unify_eq(r.constraints)
  12. substitute_all_types(traced.graph, mgu)
  13. def infer_symbolic_types(traced):
  14. """
  15. Calls our symbolic inferencer twice.
  16. This is useful when one pass is not enough
  17. to infer all the information such as the case
  18. for broadcasting.
  19. """
  20. r = Refine(traced)
  21. r.refine()
  22. mgu = unify_eq(r.constraints)
  23. substitute_all_types(traced.graph, mgu)
  24. r = Refine(traced)
  25. r.refine()
  26. mgu = unify_eq(r.constraints)
  27. substitute_all_types(traced.graph, mgu)
  28. r.symbolic_relations()
  29. def convert_eq(list_of_eq):
  30. """
  31. Convert equality constraints in the right format
  32. to be used by unification library.
  33. """
  34. lhs = []
  35. rhs = []
  36. for eq in list_of_eq:
  37. lhs.append(eq.lhs)
  38. rhs.append(eq.rhs)
  39. return tuple(lhs), tuple(rhs)
  40. def unify_eq(list_of_eq):
  41. """
  42. Apply unification to a set of
  43. equality constraints
  44. """
  45. lhs, rhs = convert_eq(list_of_eq)
  46. return unify(lhs, rhs)
  47. def substitute_solution_one_type(mapping, t):
  48. """
  49. Apply the most general unifier to a type
  50. """
  51. if isinstance(t, Var):
  52. if t in mapping:
  53. return mapping[t]
  54. else:
  55. return t
  56. elif isinstance(t, TensorType):
  57. new_type = []
  58. for typ in t.__args__:
  59. if typ in mapping:
  60. new_type.append(mapping[typ])
  61. else:
  62. new_type.append(typ)
  63. return TensorType(tuple(new_type))
  64. elif isinstance(t, list):
  65. new_type = []
  66. for typ in t:
  67. new_type.append(substitute_solution_one_type(mapping, typ))
  68. return new_type
  69. elif isinstance(t, tuple):
  70. new_type = []
  71. for typ in t:
  72. new_type.append(substitute_solution_one_type(mapping, typ))
  73. return tuple(new_type)
  74. else:
  75. return t
  76. def substitute_all_types(graph, mapping):
  77. """
  78. Apply the most general unifier to all types in a graph
  79. till reaching a fixed point. If the input and output graph
  80. are the same, we converge.
  81. """
  82. flag = True
  83. while flag:
  84. flag = False
  85. for k in mapping:
  86. old_mapping_val = mapping[k]
  87. if mapping[k] in mapping:
  88. new_key = mapping[k]
  89. mapping[k] = mapping[new_key]
  90. if old_mapping_val != mapping[k]:
  91. flag = True
  92. for n in graph.nodes:
  93. n.type = substitute_solution_one_type(mapping, n.type)
  94. def check_for_type_equality(g1, g2):
  95. """
  96. A check equality to be used in fixed points.
  97. We do not use graph equality but instead type
  98. equality.
  99. """
  100. for n, m in zip(g1.nodes, g2.nodes):
  101. if n.type != m.type:
  102. return False
  103. return True