quantizers_utils.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. from typing import Any
  16. def get_module_from_name(module, tensor_name: str) -> tuple[Any, str]:
  17. if "." in tensor_name:
  18. module_name, tensor_name = tensor_name.rsplit(".", 1)
  19. module = module.get_submodule(module_name)
  20. return module, tensor_name
  21. def should_convert_module(full_name, patterns: list[str] | None = None):
  22. if patterns is None:
  23. return True
  24. # We should avoid converting in the following situations:
  25. # 1. The pattern appears as a prefix followed by a dot in `full_name`
  26. # (e.g., "model.decoder.layer.11." matches "model.decoder.layer.11.attn.weight").
  27. # 2. The pattern matches `full_name` exactly or via regex
  28. # (e.g., "lm_head" matches "lm_head"; "model.decoder.layer.*" matches "model.decoder.layer.11.attn.weight").
  29. # 3. `full_name` ends with the pattern
  30. # (e.g., "fc1" matches "model.decoder.layers.23.fc1").
  31. should_not_convert = any(
  32. re.match(f"{key}\\.", full_name) or re.match(f"{key}", full_name) or full_name.endswith(key)
  33. for key in patterns
  34. )
  35. return not should_not_convert