# io_binding_helper.py
import copy
import logging
import math
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any

import numpy
import torch
from onnx import TensorProto
from onnxruntime import InferenceSession, RunOptions
  10. # Type alias
  11. ShapeDict = Mapping[str, tuple | list[int]]
  12. logger = logging.getLogger(__name__)
  13. class TypeHelper:
  14. @staticmethod
  15. def get_input_type(ort_session: InferenceSession, name: str) -> str:
  16. for _i, input in enumerate(ort_session.get_inputs()):
  17. if input.name == name:
  18. return input.type
  19. raise ValueError(f"input name {name} not found")
  20. @staticmethod
  21. def get_output_type(ort_session, name: str) -> str:
  22. for _i, output in enumerate(ort_session.get_outputs()):
  23. if output.name == name:
  24. return output.type
  25. raise ValueError(f"output name {name} not found")
  26. @staticmethod
  27. def ort_type_to_numpy_type(ort_type: str):
  28. ort_type_to_numpy_type_map = {
  29. "tensor(int64)": numpy.longlong,
  30. "tensor(int32)": numpy.intc,
  31. "tensor(float)": numpy.float32,
  32. "tensor(float16)": numpy.float16,
  33. "tensor(bool)": bool,
  34. "tensor(uint8)": numpy.uint8,
  35. "tensor(int8)": numpy.int8,
  36. }
  37. if ort_type not in ort_type_to_numpy_type_map:
  38. raise ValueError(f"{ort_type} not found in map")
  39. return ort_type_to_numpy_type_map[ort_type]
  40. @staticmethod
  41. def ort_type_to_torch_type(ort_type: str):
  42. ort_type_to_torch_type_map = {
  43. "tensor(int64)": torch.int64,
  44. "tensor(int32)": torch.int32,
  45. "tensor(float)": torch.float32,
  46. "tensor(float16)": torch.float16,
  47. "tensor(bfloat16)": torch.bfloat16,
  48. "tensor(bool)": torch.bool,
  49. "tensor(uint8)": torch.uint8,
  50. "tensor(int8)": torch.int8,
  51. }
  52. if ort_type not in ort_type_to_torch_type_map:
  53. raise ValueError(f"{ort_type} not found in map")
  54. return ort_type_to_torch_type_map[ort_type]
  55. @staticmethod
  56. def get_io_onnx_type_map(ort_session: InferenceSession) -> dict[str, int]:
  57. """Create a mapping from input/output name to onnx data type"""
  58. name_to_onnx_type = {}
  59. for input in ort_session.get_inputs():
  60. name_to_onnx_type[input.name] = TypeHelper.ort_type_to_onnx_type(input.type)
  61. for output in ort_session.get_outputs():
  62. name_to_onnx_type[output.name] = TypeHelper.ort_type_to_onnx_type(output.type)
  63. return name_to_onnx_type
  64. @staticmethod
  65. def ort_type_to_onnx_type(ort_type: str):
  66. ort_type_to_onnx_type_map = {
  67. "tensor(int64)": TensorProto.INT64,
  68. "tensor(int32)": TensorProto.INT32,
  69. "tensor(float)": TensorProto.FLOAT,
  70. "tensor(float16)": TensorProto.FLOAT16,
  71. "tensor(bfloat16)": TensorProto.BFLOAT16,
  72. "tensor(bool)": TensorProto.BOOL,
  73. "tensor(uint8)": TensorProto.UINT8,
  74. "tensor(int8)": TensorProto.INT8,
  75. }
  76. if ort_type not in ort_type_to_onnx_type_map:
  77. raise ValueError(f"{ort_type} not found in map")
  78. return ort_type_to_onnx_type_map[ort_type]
  79. @staticmethod
  80. def numpy_type_to_torch_type(numpy_type: numpy.dtype):
  81. numpy_type_to_torch_type_map = {
  82. numpy.longlong: torch.int64,
  83. numpy.intc: torch.int32,
  84. numpy.int32: torch.int32,
  85. numpy.float32: torch.float32,
  86. numpy.float16: torch.float16,
  87. bool: torch.bool,
  88. numpy.uint8: torch.uint8,
  89. numpy.int8: torch.int8,
  90. }
  91. if numpy_type not in numpy_type_to_torch_type_map:
  92. raise ValueError(f"{numpy_type} not found in map")
  93. return numpy_type_to_torch_type_map[numpy_type]
  94. @staticmethod
  95. def torch_type_to_numpy_type(torch_type: torch.dtype):
  96. torch_type_to_numpy_type_map = {
  97. torch.int64: numpy.longlong,
  98. torch.int32: numpy.intc,
  99. torch.float32: numpy.float32,
  100. torch.float16: numpy.float16,
  101. torch.bool: bool,
  102. torch.uint8: numpy.uint8,
  103. }
  104. if torch_type not in torch_type_to_numpy_type_map:
  105. raise ValueError(f"{torch_type} not found in map")
  106. return torch_type_to_numpy_type_map[torch_type]
  107. @staticmethod
  108. def get_io_numpy_type_map(ort_session: InferenceSession) -> dict[str, numpy.dtype]:
  109. """Create a mapping from input/output name to numpy data type"""
  110. name_to_numpy_type = {}
  111. for input in ort_session.get_inputs():
  112. name_to_numpy_type[input.name] = TypeHelper.ort_type_to_numpy_type(input.type)
  113. for output in ort_session.get_outputs():
  114. name_to_numpy_type[output.name] = TypeHelper.ort_type_to_numpy_type(output.type)
  115. return name_to_numpy_type
  116. @staticmethod
  117. def get_io_torch_type_map(ort_session: InferenceSession) -> dict[str, torch.dtype]:
  118. """Create a mapping from input/output name to torch data type"""
  119. name_to_torch_type = {}
  120. for input in ort_session.get_inputs():
  121. name_to_torch_type[input.name] = TypeHelper.ort_type_to_torch_type(input.type)
  122. for output in ort_session.get_outputs():
  123. name_to_torch_type[output.name] = TypeHelper.ort_type_to_torch_type(output.type)
  124. return name_to_torch_type
  125. class IOBindingHelper:
  126. @staticmethod
  127. def get_output_buffers(ort_session: InferenceSession, output_shapes, device):
  128. """Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape."""
  129. output_buffers = {}
  130. for name, shape in output_shapes.items():
  131. ort_type = TypeHelper.get_output_type(ort_session, name)
  132. torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
  133. output_buffers[name] = torch.empty(numpy.prod(shape), dtype=torch_type, device=device)
  134. return output_buffers
  135. @staticmethod
  136. def prepare_io_binding(
  137. ort_session,
  138. input_ids: torch.Tensor,
  139. position_ids: torch.Tensor,
  140. attention_mask: torch.Tensor,
  141. past: list[torch.Tensor],
  142. output_buffers,
  143. output_shapes,
  144. ):
  145. """IO binding for a session: bind inputs (input_ids, position_ids, attention_mask, past_*) and outputs."""
  146. name_to_onnx_type = TypeHelper.get_io_onnx_type_map(ort_session)
  147. # Bind inputs and outputs to onnxruntime session
  148. io_binding = ort_session.io_binding()
  149. # Bind inputs
  150. assert input_ids.is_contiguous()
  151. io_binding.bind_input(
  152. "input_ids",
  153. input_ids.device.type,
  154. 0,
  155. name_to_onnx_type["input_ids"],
  156. list(input_ids.size()),
  157. input_ids.data_ptr(),
  158. )
  159. if past is not None:
  160. for i, past_i in enumerate(past):
  161. assert past_i.is_contiguous()
  162. data_ptr = past_i.data_ptr()
  163. if data_ptr == 0:
  164. # When past_sequence_length is 0, its data_ptr will be zero. IO Binding asserts that data_ptr shall not be zero.
  165. # Here we workaround and pass data pointer of input_ids. Actual data is not used for past so it does not matter.
  166. data_ptr = input_ids.data_ptr()
  167. io_binding.bind_input(
  168. f"past_{i}",
  169. past_i.device.type,
  170. 0,
  171. name_to_onnx_type[f"past_{i}"],
  172. list(past_i.size()),
  173. data_ptr,
  174. )
  175. if attention_mask is not None:
  176. assert attention_mask.is_contiguous()
  177. io_binding.bind_input(
  178. "attention_mask",
  179. attention_mask.device.type,
  180. 0,
  181. name_to_onnx_type["attention_mask"],
  182. list(attention_mask.size()),
  183. attention_mask.data_ptr(),
  184. )
  185. if position_ids is not None:
  186. assert position_ids.is_contiguous()
  187. io_binding.bind_input(
  188. "position_ids",
  189. position_ids.device.type,
  190. 0,
  191. name_to_onnx_type["position_ids"],
  192. list(position_ids.size()),
  193. position_ids.data_ptr(),
  194. )
  195. # Bind outputs
  196. for output in ort_session.get_outputs():
  197. output_name = output.name
  198. output_buffer = output_buffers[output_name]
  199. logger.debug(f"{output_name} device type={output_buffer.device.type} shape={list(output_buffer.size())}")
  200. io_binding.bind_output(
  201. output_name,
  202. output_buffer.device.type,
  203. 0,
  204. name_to_onnx_type[output_name],
  205. output_shapes[output_name],
  206. output_buffer.data_ptr(),
  207. )
  208. return io_binding
  209. @staticmethod
  210. def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
  211. """Copy results to cpu. Returns a list of numpy array."""
  212. ort_outputs = []
  213. for output in ort_session.get_outputs():
  214. output_name = output.name
  215. buffer = output_buffers[output_name]
  216. shape = output_shapes[output_name]
  217. copy_tensor = buffer[0 : numpy.prod(shape)].reshape(shape).clone().detach()
  218. if return_numpy:
  219. ort_outputs.append(copy_tensor.cpu().numpy())
  220. else:
  221. ort_outputs.append(copy_tensor)
  222. return ort_outputs
  223. class CudaSession:
  224. """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider"""
  225. def __init__(self, ort_session: InferenceSession, device: torch.device, enable_cuda_graph=False):
  226. self.ort_session = ort_session
  227. self.input_names = [input.name for input in self.ort_session.get_inputs()]
  228. self.output_names = [output.name for output in self.ort_session.get_outputs()]
  229. self.io_name_to_onnx_type = TypeHelper.get_io_onnx_type_map(self.ort_session)
  230. self.io_name_to_torch_type = TypeHelper.get_io_torch_type_map(self.ort_session)
  231. self.io_binding = self.ort_session.io_binding()
  232. self.enable_cuda_graph = enable_cuda_graph
  233. self.input_tensors = OrderedDict()
  234. self.output_tensors = OrderedDict()
  235. self.device = device
  236. # Pairs of input and output names that share the same buffer.
  237. self.buffer_sharing: dict[str, str] = {}
  238. def set_buffer_sharing(self, input_name: str, output_name: str):
  239. assert input_name in self.input_names
  240. assert output_name in self.output_names
  241. self.buffer_sharing[input_name] = output_name
  242. self.buffer_sharing[output_name] = input_name
  243. def __del__(self):
  244. del self.input_tensors
  245. del self.output_tensors
  246. del self.io_binding
  247. def bind_input_and_buffer_sharing(self, name: str, tensor: torch.Tensor):
  248. device_id = tensor.device.index if tensor.device.index is not None else 0
  249. tensor_shape = [1] if len(tensor.shape) == 0 else list(tensor.shape)
  250. self.io_binding.bind_input(
  251. name,
  252. tensor.device.type,
  253. device_id,
  254. self.io_name_to_onnx_type[name],
  255. tensor_shape,
  256. tensor.data_ptr(),
  257. )
  258. if name in self.buffer_sharing:
  259. self.io_binding.bind_output(
  260. self.buffer_sharing[name],
  261. tensor.device.type,
  262. device_id,
  263. self.io_name_to_onnx_type[name],
  264. tensor_shape,
  265. tensor.data_ptr(),
  266. )
  267. self.output_tensors[self.buffer_sharing[name]] = tensor
  268. def allocate_buffers(self, shape_dict: ShapeDict):
  269. """Allocate tensors for I/O Binding"""
  270. if self.enable_cuda_graph:
  271. for name, shape in shape_dict.items():
  272. if name in self.input_names:
  273. # Reuse allocated buffer when the shape is same
  274. if name in self.input_tensors:
  275. if tuple(self.input_tensors[name].shape) == tuple(shape):
  276. continue
  277. raise RuntimeError("Expect static input shape for cuda graph")
  278. torch_dtype = self.io_name_to_torch_type[name]
  279. tensor = torch.empty(tuple(shape), dtype=torch_dtype).to(device=self.device)
  280. self.input_tensors[name] = tensor
  281. self.bind_input_and_buffer_sharing(name, tensor)
  282. for name, shape in shape_dict.items():
  283. if name in self.output_names:
  284. # Reuse allocated buffer when the shape is same
  285. if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape):
  286. continue
  287. if name in self.buffer_sharing:
  288. continue
  289. torch_dtype = self.io_name_to_torch_type[name]
  290. tensor = torch.empty(tuple(shape), dtype=torch_dtype).to(device=self.device)
  291. self.output_tensors[name] = tensor
  292. self.io_binding.bind_output(
  293. name,
  294. tensor.device.type,
  295. tensor.device.index if tensor.device.index is not None else 0,
  296. self.io_name_to_onnx_type[name],
  297. list(tensor.size()),
  298. tensor.data_ptr(),
  299. )
  300. def infer(self, feed_dict: dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True):
  301. """Bind input tensors and run inference"""
  302. for name, tensor in feed_dict.items():
  303. assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
  304. if name in self.input_names:
  305. if self.enable_cuda_graph:
  306. assert self.input_tensors[name].nelement() == tensor.nelement()
  307. assert self.input_tensors[name].dtype == tensor.dtype
  308. assert tensor.device.type == "cuda"
  309. self.input_tensors[name].copy_(tensor)
  310. else:
  311. self.bind_input_and_buffer_sharing(name, tensor)
  312. if synchronize:
  313. self.io_binding.synchronize_inputs()
  314. self.ort_session.run_with_iobinding(self.io_binding, run_options)
  315. self.io_binding.synchronize_outputs()
  316. else:
  317. self.ort_session.run_with_iobinding(self.io_binding, run_options)
  318. return self.output_tensors
  319. @staticmethod
  320. def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool, stream: int = 0) -> dict[str, Any]:
  321. options = {
  322. "device_id": device_id,
  323. "arena_extend_strategy": "kSameAsRequested",
  324. "enable_cuda_graph": enable_cuda_graph,
  325. }
  326. # Stream is address of a CUDA stream. 0 means the default stream.
  327. if stream != 0:
  328. options["user_compute_stream"] = str(stream)
  329. return options
  330. class GpuBinding(CudaSession):
  331. def __init__(
  332. self,
  333. ort_session: InferenceSession,
  334. device: torch.device,
  335. shape_dict: ShapeDict,
  336. enable_gpu_graph: bool = False,
  337. gpu_graph_id: int = -1,
  338. stream: int = 0,
  339. buffer_sharing: dict[str, str] | None = None,
  340. ):
  341. super().__init__(ort_session, device, enable_gpu_graph)
  342. if buffer_sharing:
  343. for input_name, output_name in buffer_sharing.items():
  344. self.set_buffer_sharing(input_name, output_name)
  345. self.allocate_buffers(shape_dict)
  346. self.gpu_graph_id = gpu_graph_id
  347. # For cuda graph, we need to keep a copy of shape_dict to check if the shape is same in inference later.
  348. self.shape_dict = copy.deepcopy(shape_dict) if enable_gpu_graph else None
  349. self.stream = stream
  350. # The gpu graph id of last run. It will be saved to image metadata.
  351. self.last_run_gpu_graph_id = None
  352. def get_run_options(self, disable_cuda_graph_in_run: bool = False) -> RunOptions:
  353. options = RunOptions()
  354. gpu_graph_id = -1 if disable_cuda_graph_in_run else self.gpu_graph_id
  355. options.add_run_config_entry("gpu_graph_id", str(gpu_graph_id))
  356. self.last_run_gpu_graph_id = gpu_graph_id
  357. return options
  358. def infer(self, feed_dict: dict[str, torch.Tensor], disable_cuda_graph_in_run: bool = False):
  359. run_options = self.get_run_options(disable_cuda_graph_in_run)
  360. if self.stream:
  361. run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
  362. return super().infer(feed_dict, run_options)
  363. class GpuBindingManager:
  364. """A manager for I/O bindings that support multiple CUDA Graphs.
  365. One cuda graph is reused for same input shape. Automatically add a new cuda graph for new input shape.
  366. """
  367. def __init__(self, ort_session: InferenceSession, device: torch.device, stream: int = 0, max_cuda_graphs: int = 1):
  368. self.ort_session = ort_session
  369. self.device = device
  370. # Binding supports cuda graphs. For a binding, it is able to disable cuda graph for a specific run.
  371. self.graph_bindings = []
  372. # Binding for not using cuda graph.
  373. self.no_graph_binding = None
  374. self.stream = stream
  375. self.max_cuda_graphs = max_cuda_graphs
  376. def get_binding(
  377. self,
  378. shape_dict: ShapeDict,
  379. use_cuda_graph: bool = False,
  380. buffer_sharing: dict[str, str] | None = None,
  381. ) -> GpuBinding:
  382. for gpu_graph_binding in self.graph_bindings:
  383. # Found a cuda graph that captured with the same shape
  384. if gpu_graph_binding.shape_dict == shape_dict:
  385. return gpu_graph_binding
  386. # Reached the maximum number of cuda graphs. Return a binding without cuda graph.
  387. if len(self.graph_bindings) >= self.max_cuda_graphs or (not use_cuda_graph):
  388. if self.no_graph_binding is None:
  389. self.no_graph_binding = GpuBinding(
  390. self.ort_session, self.device, shape_dict, stream=self.stream, buffer_sharing=buffer_sharing
  391. )
  392. else:
  393. self.no_graph_binding.allocate_buffers(shape_dict)
  394. return self.no_graph_binding
  395. # This is a new input shape, create a new cuda graph
  396. gpu_graph_binding = GpuBinding(
  397. self.ort_session,
  398. self.device,
  399. shape_dict,
  400. enable_gpu_graph=True,
  401. gpu_graph_id=len(self.graph_bindings),
  402. stream=self.stream,
  403. buffer_sharing=buffer_sharing,
  404. )
  405. self.graph_bindings.append(gpu_graph_binding)
  406. return gpu_graph_binding