# huggingface_models.py
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.  See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
  6. # Maps model class name to a tuple of model class
  7. MODEL_CLASSES = [
  8. "AutoModel",
  9. "AutoModelWithLMHead",
  10. "AutoModelForSequenceClassification",
  11. "AutoModelForQuestionAnswering",
  12. "AutoModelForCausalLM",
  13. ]
  14. # Pretrained model name to a tuple of input names, opset_version, use_external_data_format, optimization model type
  15. # Some models like GPT, T5, Bart etc has its own convert_to_onnx.py in models sub-directory, and they are excluded here.
  16. MODELS = {
  17. # BERT
  18. "bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"),
  19. "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"),
  20. # Transformer-XL (Models uses Einsum, which need opset version 16 or later.)
  21. "transfo-xl-wt103": (["input_ids", "mems"], 16, False, "bert"),
  22. # XLNet
  23. "xlnet-base-cased": (["input_ids"], 16, False, "bert"),
  24. "xlnet-large-cased": (["input_ids"], 16, False, "bert"),
  25. # XLM
  26. "xlm-mlm-en-2048": (["input_ids"], 16, True, "bert"),
  27. "xlm-mlm-ende-1024": (["input_ids"], 16, False, "bert"),
  28. "xlm-mlm-enfr-1024": (["input_ids"], 16, False, "bert"),
  29. # RoBERTa
  30. "roberta-base": (["input_ids", "attention_mask"], 16, False, "bert"),
  31. "roberta-large": (["input_ids", "attention_mask"], 16, False, "bert"),
  32. "roberta-large-mnli": (["input_ids", "attention_mask"], 16, False, "bert"),
  33. "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 16, False, "bert"),
  34. "distilroberta-base": (["input_ids", "attention_mask"], 16, False, "bert"),
  35. # DistilBERT
  36. "distilbert-base-uncased": (["input_ids", "attention_mask"], 16, False, "bert"),
  37. "distilbert-base-uncased-distilled-squad": (["input_ids", "attention_mask"], 16, False, "bert"),
  38. # CTRL
  39. "ctrl": (["input_ids"], 16, True, "bert"),
  40. # CamemBERT
  41. "camembert-base": (["input_ids"], 16, False, "bert"),
  42. # ALBERT
  43. "albert-base-v1": (["input_ids"], 16, False, "bert"),
  44. "albert-large-v1": (["input_ids"], 16, False, "bert"),
  45. "albert-xlarge-v1": (["input_ids"], 16, True, "bert"),
  46. # "albert-xxlarge-v1": (["input_ids"], 16, True, "bert"),
  47. "albert-base-v2": (["input_ids"], 16, False, "bert"),
  48. "albert-large-v2": (["input_ids"], 16, False, "bert"),
  49. "albert-xlarge-v2": (["input_ids"], 16, True, "bert"),
  50. # "albert-xxlarge-v2": (["input_ids"], 16, True, "bert"),
  51. # XLM-RoBERTa
  52. "xlm-roberta-base": (["input_ids"], 16, False, "bert"),
  53. "xlm-roberta-large": (["input_ids"], 16, True, "bert"),
  54. # FlauBERT
  55. "flaubert/flaubert_small_cased": (["input_ids"], 16, False, "bert"),
  56. "flaubert/flaubert_base_cased": (["input_ids"], 16, False, "bert"),
  57. # "flaubert/flaubert_large_cased": (["input_ids"], 16, False, "bert"),
  58. # Layoutlm
  59. "microsoft/layoutlm-base-uncased": (["input_ids"], 16, False, "bert"),
  60. "microsoft/layoutlm-large-uncased": (["input_ids"], 16, False, "bert"),
  61. # Squeezebert
  62. "squeezebert/squeezebert-uncased": (["input_ids"], 16, False, "bert"),
  63. "squeezebert/squeezebert-mnli": (["input_ids"], 16, False, "bert"),
  64. "squeezebert/squeezebert-mnli-headless": (["input_ids"], 16, False, "bert"),
  65. "unc-nlp/lxmert-base-uncased": (["input_ids", "visual_feats", "visual_pos"], 16, False, "bert"),
  66. # ViT
  67. "google/vit-base-patch16-224": (["pixel_values"], 16, False, "vit"),
  68. # Swin
  69. "microsoft/swin-base-patch4-window7-224": (["pixel_values"], 16, False, "swin"),
  70. "microsoft/swin-small-patch4-window7-224": (["pixel_values"], 16, False, "swin"),
  71. "microsoft/swin-tiny-patch4-window7-224": (["pixel_values"], 16, False, "swin"),
  72. }