#
# The implementation of this file is based on:
# https://github.com/intel/neural-compressor/tree/master/neural_compressor
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper classes and functions for the onnxrt adaptor."""
  19. import importlib
  20. import logging
  21. import numpy as np
  22. logger = logging.getLogger("neural_compressor")
  23. MAXIMUM_PROTOBUF = 2147483648
  24. def simple_progress_bar(total, i):
  25. """Progress bar for cases where tqdm can't be used."""
  26. progress = i / total
  27. bar_length = 20
  28. bar = "#" * int(bar_length * progress)
  29. spaces = " " * (bar_length - len(bar))
  30. percentage = progress * 100
  31. print(f"\rProgress: [{bar}{spaces}] {percentage:.2f}%", end="")
  32. def find_by_name(name, item_list):
  33. """Helper function to find item by name in a list."""
  34. items = []
  35. for item in item_list:
  36. assert hasattr(item, "name"), f"{item} should have a 'name' attribute defined" # pragma: no cover
  37. if item.name == name:
  38. items.append(item)
  39. if len(items) > 0:
  40. return items[0]
  41. else:
  42. return None
  43. def to_numpy(data):
  44. """Convert to numpy ndarrays."""
  45. import torch # noqa: PLC0415
  46. if not isinstance(data, np.ndarray):
  47. if not importlib.util.find_spec("torch"):
  48. logger.error(
  49. "Please install torch to enable subsequent data type check and conversion, "
  50. "or reorganize your data format to numpy array."
  51. )
  52. exit(0)
  53. if isinstance(data, torch.Tensor):
  54. if data.dtype is torch.bfloat16: # pragma: no cover
  55. return data.detach().cpu().to(torch.float32).numpy()
  56. if data.dtype is torch.chalf: # pragma: no cover
  57. return data.detach().cpu().to(torch.cfloat).numpy()
  58. return data.detach().cpu().numpy()
  59. else:
  60. try:
  61. return np.array(data)
  62. except Exception:
  63. assert False, ( # noqa: B011
  64. f"The input data for onnx model is {type(data)}, which is not supported to convert to numpy ndarrays."
  65. )
  66. else:
  67. return data