# Package initializer: imports the names listed below from the sibling
# submodules so that they are available directly from this package's namespace.
#
# NOTE(review): statement order is kept exactly as found; it may matter if the
# submodules import from one another -- confirm before re-sorting.
from ._fx import (
    create_feature_extractor,
    get_graph_node_names,
    register_notrace_function,
    register_notrace_module,
    is_notrace_module,
    is_notrace_function,
    get_notrace_modules,
    get_notrace_functions,
)
# Wildcard re-export: the exported names are whatever `.activations` exposes.
from .activations import *
from .adaptive_avgmax_pool import (
    adaptive_avgmax_pool2d,
    select_adaptive_pool2d,
    AdaptiveAvgMaxPool2d,
    SelectAdaptivePool2d,
)
from .attention import Attention, AttentionRope, maybe_add_mask, resolve_self_attn_mask
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent, AttentionPoolPrr
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d
from .blur_pool import BlurPool2d, create_aa
from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import (
    is_exportable,
    is_scriptable,
    is_no_jit,
    use_fused_attn,
    set_exportable,
    set_scriptable,
    set_no_jit,
    set_layer_config,
    set_fused_attn,
    set_reentrant_ckpt,
    use_reentrant_ckpt,
)
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm import get_norm_layer, create_norm_layer
# `get_norm_act_layer` was previously listed twice in this statement; once is enough.
from .create_norm_act import get_norm_act_layer, create_norm_act_layer
from .diff_attention import DiffAttention
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path, calculate_drop_path_rates
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import (
    EvoNorm2dB0,
    EvoNorm2dB1,
    EvoNorm2dB2,
    EvoNorm2dS0,
    EvoNorm2dS0a,
    EvoNorm2dS1,
    EvoNorm2dS1a,
    EvoNorm2dS2,
    EvoNorm2dS2a,
)
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .grid import ndgrid, meshgrid
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
from .inplace_abn import InplaceAbn
from .layer_scale import LayerScale, LayerScale2d
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import (
    GroupNorm,
    GroupNorm1,
    LayerNorm,
    LayerNorm2d,
    LayerNormFp32,
    LayerNorm2dFp32,
    RmsNorm,
    RmsNorm2d,
    RmsNormFp32,
    RmsNorm2dFp32,
    SimpleNorm,
    SimpleNorm2d,
    SimpleNormFp32,
    SimpleNorm2dFp32,
)
from .norm_act import (
    BatchNormAct2d,
    GroupNormAct,
    GroupNorm1Act,
    LayerNormAct,
    LayerNormAct2d,
    LayerNormActFp32,
    LayerNormAct2dFp32,
    RmsNormAct,
    RmsNormAct2d,
    RmsNormActFp32,
    RmsNormAct2dFp32,
    SyncBatchNormAct,
    convert_sync_batchnorm,
    FrozenBatchNormAct2d,
    freeze_batch_norm_2d,
    unfreeze_batch_norm_2d,
)
from .padding import get_padding, get_same_padding, pad_same
from .patch_dropout import PatchDropout, PatchDropoutWithIndices, patch_dropout_forward
from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
from .pool1d import global_pool_nlc
from .other_pool import LsePlus2d, LsePlus1d, SimPool2d, SimPool1d
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
from .pos_embed_rel import (
    RelPosMlp,
    RelPosBias,
    RelPosBiasTf,
    gen_relative_position_index,
    gen_relative_log_coords,
    resize_rel_pos_bias_table,
    resize_rel_pos_bias_table_simple,
    resize_rel_pos_bias_table_levit,
)
from .pos_embed_sincos import (
    pixel_freq_bands,
    freq_bands,
    build_sincos2d_pos_embed,
    build_fourier_pos_embed,
    build_rotary_pos_embed,
    apply_rot_embed,
    apply_rot_embed_cat,
    apply_rot_embed_list,
    apply_keep_indices_nlc,
    FourierEmbed,
    RotaryEmbedding,
    RotaryEmbeddingCat,
    RotaryEmbeddingMixed,
    RotaryEmbeddingDinoV3,
    get_mixed_freqs,
    create_rope_embed,
)
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
from .space_to_depth import SpaceToDepth, DepthToSpace
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .typing import LayerType, PadType, disable_compiler
from .weight_init import (
    is_meta_device,
    trunc_normal_,
    trunc_normal_tf_,
    variance_scaling_,
    lecun_normal_,
    init_weight_jax,
    init_weight_vit,
)