stubs.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from typing import Any
  2. import torch
  3. from torch import nn
  4. from torch.ao.quantization import QConfig
  5. __all__ = ["QuantStub", "DeQuantStub", "QuantWrapper"]
  6. class QuantStub(nn.Module):
  7. r"""Quantize stub module, before calibration, this is same as an observer,
  8. it will be swapped as `nnq.Quantize` in `convert`.
  9. Args:
  10. qconfig: quantization configuration for the tensor,
  11. if qconfig is not provided, we will get qconfig from parent modules
  12. """
  13. def __init__(self, qconfig: QConfig | None = None):
  14. super().__init__()
  15. if qconfig:
  16. self.qconfig = qconfig
  17. def forward(self, x: torch.Tensor) -> torch.Tensor:
  18. return x
  19. class DeQuantStub(nn.Module):
  20. r"""Dequantize stub module, before calibration, this is same as identity,
  21. this will be swapped as `nnq.DeQuantize` in `convert`.
  22. Args:
  23. qconfig: quantization configuration for the tensor,
  24. if qconfig is not provided, we will get qconfig from parent modules
  25. """
  26. def __init__(self, qconfig: Any | None = None):
  27. super().__init__()
  28. if qconfig:
  29. self.qconfig = qconfig
  30. def forward(self, x: torch.Tensor) -> torch.Tensor:
  31. return x
  32. class QuantWrapper(nn.Module):
  33. r"""A wrapper class that wraps the input module, adds QuantStub and
  34. DeQuantStub and surround the call to module with call to quant and dequant
  35. modules.
  36. This is used by the `quantization` utility functions to add the quant and
  37. dequant modules, before `convert` function `QuantStub` will just be observer,
  38. it observes the input tensor, after `convert`, `QuantStub`
  39. will be swapped to `nnq.Quantize` which does actual quantization. Similarly
  40. for `DeQuantStub`.
  41. """
  42. quant: QuantStub
  43. dequant: DeQuantStub
  44. module: nn.Module
  45. def __init__(self, module: nn.Module):
  46. super().__init__()
  47. qconfig = getattr(module, "qconfig", None)
  48. self.add_module("quant", QuantStub(qconfig))
  49. self.add_module("dequant", DeQuantStub(qconfig))
  50. self.add_module("module", module)
  51. self.train(module.training)
  52. def forward(self, X: torch.Tensor) -> torch.Tensor:
  53. X = self.quant(X)
  54. X = self.module(X)
  55. return self.dequant(X)