hf.py 966 B

1234567891011121314151617181920212223
  1. import json
  2. import torch
  3. from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
  4. from transformers.utils.hub import cached_file
  5. def load_config_hf(model_name):
  6. resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
  7. return json.load(open(resolved_archive_file))
  8. def load_state_dict_hf(model_name, device=None, dtype=None):
  9. # If not fp32, then we don't want to load directly to the GPU
  10. mapped_device = "cpu" if dtype not in [torch.float32, None] else device
  11. resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
  12. return torch.load(resolved_archive_file, map_location=mapped_device)
  13. # Convert dtype before moving to GPU to save memory
  14. if dtype is not None:
  15. state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
  16. state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
  17. return state_dict