__init__.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # mypy: allow-untyped-defs
  2. from .fake_quantize import * # noqa: F403
  3. from .fuse_modules import fuse_modules
  4. from .fuser_method_mappings import * # noqa: F403
  5. from .observer import * # noqa: F403
  6. from .qconfig import * # noqa: F403
  7. from .quant_type import * # noqa: F403
  8. from .quantization_mappings import * # noqa: F403
  9. from .quantize import * # noqa: F403
  10. from .quantize_jit import * # noqa: F403
  11. from .stubs import * # noqa: F403
  12. def default_eval_fn(model, calib_data):
  13. r"""
  14. Default evaluation function takes a torch.utils.data.Dataset or a list of
  15. input Tensors and run the model on the dataset
  16. """
  17. for data, _target in calib_data:
  18. model(data)
  19. __all__ = [
  20. "QuantWrapper",
  21. "QuantStub",
  22. "DeQuantStub",
  23. # Top level API for eager mode quantization
  24. "quantize",
  25. "quantize_dynamic",
  26. "quantize_qat",
  27. "prepare",
  28. "convert",
  29. "prepare_qat",
  30. # Top level API for graph mode quantization on TorchScript
  31. "quantize_jit",
  32. "quantize_dynamic_jit",
  33. # pyrefly: ignore [bad-dunder-all]
  34. "_prepare_ondevice_dynamic_jit",
  35. # pyrefly: ignore [bad-dunder-all]
  36. "_convert_ondevice_dynamic_jit",
  37. # pyrefly: ignore [bad-dunder-all]
  38. "_quantize_ondevice_dynamic_jit",
  39. # Top level API for graph mode quantization on GraphModule(torch.fx)
  40. # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
  41. # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
  42. "QuantType", # quantization type
  43. # custom module APIs
  44. "get_default_static_quant_module_mappings",
  45. "get_static_quant_module_class",
  46. "get_default_dynamic_quant_module_mappings",
  47. "get_default_qat_module_mappings",
  48. "get_default_qconfig_propagation_list",
  49. "get_default_compare_output_module_list",
  50. "get_quantized_operator",
  51. "get_fuser_method",
  52. # Sub functions for `prepare` and `swap_module`
  53. "propagate_qconfig_",
  54. "add_quant_dequant",
  55. "swap_module",
  56. "default_eval_fn",
  57. # Observers
  58. "ObserverBase",
  59. # pyrefly: ignore [bad-dunder-all]
  60. "WeightObserver",
  61. "HistogramObserver",
  62. "observer",
  63. "default_observer",
  64. "default_weight_observer",
  65. "default_placeholder_observer",
  66. "default_per_channel_weight_observer",
  67. # FakeQuantize (for qat)
  68. "default_fake_quant",
  69. "default_weight_fake_quant",
  70. "default_fixed_qparams_range_neg1to1_fake_quant",
  71. "default_fixed_qparams_range_0to1_fake_quant",
  72. "default_per_channel_weight_fake_quant",
  73. "default_histogram_fake_quant",
  74. # QConfig
  75. "QConfig",
  76. "default_qconfig",
  77. "default_dynamic_qconfig",
  78. "float16_dynamic_qconfig",
  79. "float_qparams_weight_only_qconfig",
  80. # QAT utilities
  81. "default_qat_qconfig",
  82. "prepare_qat",
  83. "quantize_qat",
  84. # module transformations
  85. "fuse_modules",
  86. ]