_mappings.py 597 B

1234567891011121314151617181920212223
  1. # mypy: allow-untyped-defs
  2. __all__ = [
  3. "get_static_sparse_quantized_mapping",
  4. "get_dynamic_sparse_quantized_mapping",
  5. ]
  6. def get_static_sparse_quantized_mapping():
  7. import torch.ao.nn.sparse
  8. _static_sparse_quantized_mapping = {
  9. torch.nn.Linear: torch.ao.nn.sparse.quantized.Linear,
  10. }
  11. return _static_sparse_quantized_mapping
  12. def get_dynamic_sparse_quantized_mapping():
  13. import torch.ao.nn.sparse
  14. _dynamic_sparse_quantized_mapping = {
  15. torch.nn.Linear: torch.ao.nn.sparse.quantized.dynamic.Linear,
  16. }
  17. return _dynamic_sparse_quantized_mapping