offline_tuning.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright (c) Microsoft Corporation. All rights reserved.
  2. # Licensed under the MIT License.
  3. import argparse
  4. import copy
  5. import json
  6. import sys
  7. from collections import OrderedDict
  8. from pprint import pprint
  9. from typing import Any
  10. import onnx
# Type alias for one tuning-results record: a JSON object whose "ep",
# "validators" and "results" keys are read by the Merger below.
TuningResults = dict[str, Any]
# Key of the ONNX ``metadata_props`` entry that holds the JSON-serialized results.
_TUNING_RESULTS_KEY = "tuning_results"
  13. def _find_tuning_results_in_props(metadata_props):
  14. for idx, prop in enumerate(metadata_props):
  15. if prop.key == _TUNING_RESULTS_KEY:
  16. return idx
  17. return -1
  18. def extract(model: onnx.ModelProto):
  19. idx = _find_tuning_results_in_props(model.metadata_props)
  20. if idx < 0:
  21. return None
  22. tuning_results_prop = model.metadata_props[idx]
  23. return json.loads(tuning_results_prop.value)
  24. def embed(model: onnx.ModelProto, tuning_results: list[TuningResults], overwrite=False):
  25. idx = _find_tuning_results_in_props(model.metadata_props)
  26. assert overwrite or idx <= 0, "the supplied onnx file already have tuning results embedded!"
  27. if idx >= 0:
  28. model.metadata_props.pop(idx)
  29. entry = model.metadata_props.add()
  30. entry.key = _TUNING_RESULTS_KEY
  31. entry.value = json.dumps(tuning_results)
  32. return model
  33. class Merger:
  34. class EpAndValidators:
  35. def __init__(self, ep: str, validators: dict[str, str]):
  36. self.ep = ep
  37. self.validators = copy.deepcopy(validators)
  38. self.key = (ep, tuple(sorted(validators.items())))
  39. def __hash__(self):
  40. return hash(self.key)
  41. def __eq__(self, other):
  42. return self.ep == other.ep and self.key == other.key
  43. def __init__(self):
  44. self.ev_to_results = OrderedDict()
  45. def merge(self, tuning_results: list[TuningResults]):
  46. for trs in tuning_results:
  47. self._merge_one(trs)
  48. def get_merged(self):
  49. tuning_results = []
  50. for ev, flat_results in self.ev_to_results.items():
  51. results = {}
  52. trs = {
  53. "ep": ev.ep,
  54. "validators": ev.validators,
  55. "results": results,
  56. }
  57. for (op_sig, params_sig), kernel_id in flat_results.items():
  58. kernel_map = results.setdefault(op_sig, {})
  59. kernel_map[params_sig] = kernel_id
  60. tuning_results.append(trs)
  61. return tuning_results
  62. def _merge_one(self, trs: TuningResults):
  63. ev = Merger.EpAndValidators(trs["ep"], trs["validators"])
  64. flat_results = self.ev_to_results.setdefault(ev, {})
  65. for op_sig, kernel_map in trs["results"].items():
  66. for params_sig, kernel_id in kernel_map.items():
  67. if (op_sig, params_sig) not in flat_results:
  68. flat_results[(op_sig, params_sig)] = kernel_id
  69. def parse_args():
  70. parser = argparse.ArgumentParser()
  71. sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd")
  72. extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.")
  73. extract_parser.add_argument("input_onnx")
  74. extract_parser.add_argument("output_json")
  75. embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.")
  76. embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.")
  77. embed_parser.add_argument("output_onnx", help="Path of the output onnx file.")
  78. embed_parser.add_argument("input_onnx", help="Path of the input onnx file.")
  79. embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.")
  80. merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.")
  81. merge_parser.add_argument("output_json", help="Path of the output tuning results file.")
  82. merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.")
  83. pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.")
  84. pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.")
  85. args = parser.parse_args()
  86. if len(vars(args)) == 0:
  87. parser.print_help()
  88. exit(-1)
  89. return args
  90. def main():
  91. args = parse_args()
  92. if args.cmd == "extract":
  93. tuning_results = extract(onnx.load_model(args.input_onnx))
  94. if tuning_results is None:
  95. sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
  96. sys.exit(-1)
  97. json.dump(tuning_results, open(args.output_json, "w")) # noqa: SIM115
  98. elif args.cmd == "embed":
  99. model = onnx.load_model(args.input_onnx)
  100. merger = Merger()
  101. for tuning_results in [json.load(open(f)) for f in args.input_json]: # noqa: SIM115
  102. merger.merge(tuning_results)
  103. model = embed(model, merger.get_merged(), args.force)
  104. onnx.save_model(model, args.output_onnx)
  105. elif args.cmd == "merge":
  106. merger = Merger()
  107. for tuning_results in [json.load(open(f)) for f in args.input_json]: # noqa: SIM115
  108. merger.merge(tuning_results)
  109. json.dump(merger.get_merged(), open(args.output_json, "w")) # noqa: SIM115
  110. elif args.cmd == "pprint":
  111. tuning_results = None
  112. try: # noqa: SIM105
  113. tuning_results = json.load(open(args.json_or_onnx)) # noqa: SIM115
  114. except Exception:
  115. # it might be an onnx file otherwise, try it latter
  116. pass
  117. if tuning_results is None:
  118. try:
  119. model = onnx.load_model(args.json_or_onnx)
  120. tuning_results = extract(model)
  121. if tuning_results is None:
  122. sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
  123. sys.exit(-1)
  124. except Exception:
  125. pass
  126. if tuning_results is None:
  127. sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!")
  128. sys.exit(-1)
  129. pprint(tuning_results)
  130. else:
  131. # invalid choice will be handled by the parser
  132. pass
# Run the command-line tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()