| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- from typing import Any
- import torch
- from torch import nn
- from torch.ao.quantization import QConfig
# Public names exported by ``from ... import *``.
__all__ = [
    "QuantStub",
    "DeQuantStub",
    "QuantWrapper",
]
- class QuantStub(nn.Module):
- r"""Quantize stub module, before calibration, this is same as an observer,
- it will be swapped as `nnq.Quantize` in `convert`.
- Args:
- qconfig: quantization configuration for the tensor,
- if qconfig is not provided, we will get qconfig from parent modules
- """
- def __init__(self, qconfig: QConfig | None = None):
- super().__init__()
- if qconfig:
- self.qconfig = qconfig
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return x
- class DeQuantStub(nn.Module):
- r"""Dequantize stub module, before calibration, this is same as identity,
- this will be swapped as `nnq.DeQuantize` in `convert`.
- Args:
- qconfig: quantization configuration for the tensor,
- if qconfig is not provided, we will get qconfig from parent modules
- """
- def __init__(self, qconfig: Any | None = None):
- super().__init__()
- if qconfig:
- self.qconfig = qconfig
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return x
class QuantWrapper(nn.Module):
    r"""Wrap a module so that its input is quantized and its output dequantized.

    A ``QuantStub`` is registered before the wrapped module and a
    ``DeQuantStub`` after it. The ``quantization`` utility functions use this
    wrapper to insert the quant/dequant modules. Before ``convert``, the
    ``QuantStub`` acts as an observer on the input tensor. After ``convert``,
    it is swapped for ``nnq.Quantize``, which performs the real quantization.
    The ``DeQuantStub`` is handled in the same way.
    """

    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module: nn.Module):
        super().__init__()
        # Both stubs reuse the wrapped module's qconfig when it has one.
        inherited_qconfig = getattr(module, "qconfig", None)
        # Registration order matters for child iteration: quant, dequant, module.
        children = (
            ("quant", QuantStub(inherited_qconfig)),
            ("dequant", DeQuantStub(inherited_qconfig)),
            ("module", module),
        )
        for child_name, child in children:
            self.add_module(child_name, child)
        # Mirror the wrapped module's train/eval mode across the wrapper.
        self.train(module.training)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # quant -> wrapped module -> dequant
        return self.dequant(self.module(self.quant(X)))
|