# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
  5. import logging
  6. import os
  7. import sys
  8. # In ORT Package the symbolic_shape_infer.py is in ../tools
  9. file_path = os.path.dirname(__file__)
  10. if os.path.exists(os.path.join(file_path, "../tools/symbolic_shape_infer.py")):
  11. sys.path.append(os.path.join(file_path, "../tools"))
  12. else:
  13. sys.path.append(os.path.join(file_path, ".."))
  14. from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy # noqa: E402
  15. logger = logging.getLogger(__name__)
  16. class SymbolicShapeInferenceHelper(SymbolicShapeInference):
  17. def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_output_rank=False):
  18. super().__init__(int_max, auto_merge, guess_output_rank, verbose)
  19. self.model_ = model
  20. self.all_shapes_inferred_: bool = False
  21. self.is_inferred_: bool = False
  22. self.dynamic_axis_mapping_: dict[str, int] = {}
  23. def infer(self, dynamic_axis_mapping: dict[str, int], max_runs: int = 200):
  24. """Run shape inference, and try replace dynamic axis from string to integer when mapping is provided.
  25. Args:
  26. dynamic_axis_mapping (_type_): a dictionary with name of dynamic axis as key, like {"batch_size" : 4}
  27. max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 200.
  28. Returns:
  29. bool: whether all shapes has been inferred or not.
  30. """
  31. assert dynamic_axis_mapping is not None
  32. if self.is_inferred_ and self.dynamic_axis_mapping_ == dynamic_axis_mapping:
  33. return self.all_shapes_inferred_
  34. self.dynamic_axis_mapping_ = dynamic_axis_mapping
  35. self._preprocess(self.model_)
  36. count = 0
  37. while self.run_:
  38. logger.debug(f"shape infer run {count}")
  39. self.all_shapes_inferred_ = self._infer_impl()
  40. count += 1
  41. if max_runs > 0 and count >= max_runs:
  42. break
  43. self.is_inferred_ = True
  44. return self.all_shapes_inferred_
  45. def _get_sympy_shape(self, node, idx):
  46. """Override it to ensure shape inference by giving the actual value of dynamic axis."""
  47. sympy_shape = []
  48. shape = self._get_shape(node, idx)
  49. if shape:
  50. for dim in shape:
  51. if isinstance(dim, str):
  52. if dim in self.dynamic_axis_mapping_:
  53. sympy_shape.append(self.dynamic_axis_mapping_[dim])
  54. elif dim in self.symbolic_dims_:
  55. sympy_shape.append(self.symbolic_dims_[dim])
  56. else:
  57. sympy_shape.append(sympy.Symbol(dim, integer=True))
  58. else:
  59. assert dim is not None
  60. sympy_shape.append(dim)
  61. return sympy_shape
  62. def get_edge_shape(self, edge):
  63. """Get shape of an edge.
  64. Args:
  65. edge (str): name of edge
  66. Returns:
  67. Optional[List[int]]: the shape, or None if shape is unknown
  68. """
  69. assert self.all_shapes_inferred_
  70. if edge not in self.known_vi_:
  71. print("Cannot retrieve the shape of " + str(edge))
  72. return None
  73. type_proto = self.known_vi_[edge].type
  74. shape = get_shape_from_type_proto(type_proto)
  75. if shape is not None:
  76. for i, dim in enumerate(shape):
  77. if isinstance(dim, str) and dim in self.dynamic_axis_mapping_:
  78. shape[i] = self.dynamic_axis_mapping_[dim]
  79. return shape
  80. def compare_shape(self, edge, edge_other):
  81. """Compare shape of two edges.
  82. Args:
  83. edge (str): name of edge
  84. edge_other (str): name of another edge
  85. Raises:
  86. Exception: At least one shape is missed for edges to compare
  87. Returns:
  88. bool: whether the shape is same or not
  89. """
  90. assert self.all_shapes_inferred_
  91. shape = self.get_edge_shape(edge)
  92. shape_other = self.get_edge_shape(edge_other)
  93. if shape is None or shape_other is None:
  94. raise Exception("At least one shape is missed for edges to compare")
  95. return shape == shape_other