config.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. """ Model / Layer Config singleton state
  2. """
  3. import os
  4. import warnings
  5. from typing import Any, Optional
  6. import torch
  7. __all__ = [
  8. 'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
  9. 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn',
  10. 'set_reentrant_ckpt', 'use_reentrant_ckpt'
  11. ]
  12. # Set to True if prefer to have layers with no jit optimization (includes activations)
  13. _NO_JIT = False
  14. # Set to True if prefer to have activation layers with no jit optimization
  15. # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
  16. # the jit flags so far are activations. This will change as more layers are updated and/or added.
  17. _NO_ACTIVATION_JIT = False
  18. # Set to True if exporting a model with Same padding via ONNX
  19. _EXPORTABLE = False
  20. # Set to True if wanting to use torch.jit.script on a model
  21. _SCRIPTABLE = False
  22. # use torch.scaled_dot_product_attention where possible
  23. _HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
  24. if 'TIMM_FUSED_ATTN' in os.environ:
  25. _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
  26. else:
  27. _USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
  28. if 'TIMM_REENTRANT_CKPT' in os.environ:
  29. _USE_REENTRANT_CKPT = bool(os.environ['TIMM_REENTRANT_CKPT'])
  30. else:
  31. _USE_REENTRANT_CKPT = False # defaults to disabled (off)
  32. def is_no_jit():
  33. return _NO_JIT
  34. class set_no_jit:
  35. def __init__(self, mode: bool) -> None:
  36. global _NO_JIT
  37. self.prev = _NO_JIT
  38. _NO_JIT = mode
  39. def __enter__(self) -> None:
  40. pass
  41. def __exit__(self, *args: Any) -> bool:
  42. global _NO_JIT
  43. _NO_JIT = self.prev
  44. return False
  45. def is_exportable():
  46. return _EXPORTABLE
  47. class set_exportable:
  48. def __init__(self, mode: bool) -> None:
  49. global _EXPORTABLE
  50. self.prev = _EXPORTABLE
  51. _EXPORTABLE = mode
  52. def __enter__(self) -> None:
  53. pass
  54. def __exit__(self, *args: Any) -> bool:
  55. global _EXPORTABLE
  56. _EXPORTABLE = self.prev
  57. return False
  58. def is_scriptable():
  59. return _SCRIPTABLE
  60. class set_scriptable:
  61. def __init__(self, mode: bool) -> None:
  62. global _SCRIPTABLE
  63. self.prev = _SCRIPTABLE
  64. _SCRIPTABLE = mode
  65. def __enter__(self) -> None:
  66. pass
  67. def __exit__(self, *args: Any) -> bool:
  68. global _SCRIPTABLE
  69. _SCRIPTABLE = self.prev
  70. return False
  71. class set_layer_config:
  72. """ Layer config context manager that allows setting all layer config flags at once.
  73. If a flag arg is None, it will not change the current value.
  74. """
  75. def __init__(
  76. self,
  77. scriptable: Optional[bool] = None,
  78. exportable: Optional[bool] = None,
  79. no_jit: Optional[bool] = None,
  80. no_activation_jit: Optional[bool] = None):
  81. global _SCRIPTABLE
  82. global _EXPORTABLE
  83. global _NO_JIT
  84. global _NO_ACTIVATION_JIT
  85. self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
  86. if scriptable is not None:
  87. _SCRIPTABLE = scriptable
  88. if exportable is not None:
  89. _EXPORTABLE = exportable
  90. if no_jit is not None:
  91. _NO_JIT = no_jit
  92. if no_activation_jit is not None:
  93. _NO_ACTIVATION_JIT = no_activation_jit
  94. def __enter__(self) -> None:
  95. pass
  96. def __exit__(self, *args: Any) -> bool:
  97. global _SCRIPTABLE
  98. global _EXPORTABLE
  99. global _NO_JIT
  100. global _NO_ACTIVATION_JIT
  101. _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
  102. return False
  103. def use_fused_attn(experimental: bool = False) -> bool:
  104. # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
  105. if not _HAS_FUSED_ATTN or _EXPORTABLE:
  106. return False
  107. if experimental:
  108. return _USE_FUSED_ATTN > 1
  109. return _USE_FUSED_ATTN > 0
  110. def set_fused_attn(enable: bool = True, experimental: bool = False):
  111. global _USE_FUSED_ATTN
  112. if not _HAS_FUSED_ATTN:
  113. warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.')
  114. return
  115. if experimental and enable:
  116. _USE_FUSED_ATTN = 2
  117. elif enable:
  118. _USE_FUSED_ATTN = 1
  119. else:
  120. _USE_FUSED_ATTN = 0
  121. def use_reentrant_ckpt() -> bool:
  122. return _USE_REENTRANT_CKPT
  123. def set_reentrant_ckpt(enable: bool = True):
  124. global _USE_REENTRANT_CKPT
  125. _USE_REENTRANT_CKPT = enable