configuration_tvp.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """TVP model configuration"""
  15. from huggingface_hub.dataclasses import strict
  16. from ...backbone_utils import consolidate_backbone_kwargs_to_config
  17. from ...configuration_utils import PreTrainedConfig
  18. from ...utils import auto_docstring
  19. from ..auto import AutoConfig
  20. @auto_docstring(checkpoint="Intel/tvp-base")
  21. @strict
  22. class TvpConfig(PreTrainedConfig):
  23. r"""
  24. distance_loss_weight (`float`, *optional*, defaults to 1.0):
  25. The weight of distance loss.
  26. duration_loss_weight (`float`, *optional*, defaults to 0.1):
  27. The weight of duration loss.
  28. visual_prompter_type (`str`, *optional*, defaults to `"framepad"`):
  29. Visual prompt type. The type of padding. Framepad means padding on each frame. Should be one of "framepad"
  30. or "framedownpad"
  31. visual_prompter_apply (`str`, *optional*, defaults to `"replace"`):
  32. The way of applying visual prompt. Replace means use the value of prompt to change the original value in
  33. visual inputs. Should be one of "replace", or "add", or "remove".
  34. visual_prompt_size (`int`, *optional*, defaults to 96):
  35. The size of visual prompt.
  36. max_img_size (`int`, *optional*, defaults to 448):
  37. The maximum size of frame.
  38. num_frames (`int`, *optional*, defaults to 48):
  39. The number of frames extracted from a video.
  40. max_position_embeddings (`int`, *optional*, defaults to 512):
  41. The maximum sequence length that this model might ever be used with. Typically set this to something large
  42. just in case (e.g., 512 or 1024 or 2048).
  43. max_grid_col_position_embeddings (`int`, *optional*, defaults to 100):
  44. The largest number of horizontal patches from a video frame.
  45. max_grid_row_position_embeddings (`int`, *optional*, defaults to 100):
  46. The largest number of vertical patches from a video frame.
  47. """
  48. model_type = "tvp"
  49. sub_configs = {"backbone_config": AutoConfig}
  50. backbone_config: dict | PreTrainedConfig | None = None
  51. distance_loss_weight: float = 1.0
  52. duration_loss_weight: float = 0.1
  53. visual_prompter_type: str = "framepad"
  54. visual_prompter_apply: str = "replace"
  55. visual_prompt_size: int = 96
  56. max_img_size: int = 448
  57. num_frames: int = 48
  58. vocab_size: int = 30522
  59. type_vocab_size: int = 2
  60. hidden_size: int = 768
  61. intermediate_size: int = 3072
  62. num_hidden_layers: int = 12
  63. num_attention_heads: int = 12
  64. max_position_embeddings: int = 512
  65. max_grid_col_position_embeddings: int = 100
  66. max_grid_row_position_embeddings: int = 100
  67. hidden_dropout_prob: float | int = 0.1
  68. hidden_act: str = "gelu"
  69. layer_norm_eps: float = 1e-12
  70. initializer_range: float = 0.02
  71. attention_probs_dropout_prob: float | int = 0.1
  72. pad_token_id: int | None = None
  73. def __post_init__(self, **kwargs):
  74. self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
  75. backbone_config=self.backbone_config,
  76. default_config_type="resnet",
  77. default_config_kwargs={"out_features": ["stage4"]},
  78. **kwargs,
  79. )
  80. super().__post_init__(**kwargs)
  81. __all__ = ["TvpConfig"]