pos_embed.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. """ Position Embedding Utilities
  2. Hacked together by / Copyright 2022 Ross Wightman
  3. """
  4. import logging
  5. import math
  6. from typing import List, Tuple, Optional, Union
  7. import torch
  8. import torch.nn.functional as F
  9. from ._fx import register_notrace_function
  10. _logger = logging.getLogger(__name__)
  @torch.fx.wrap
  @register_notrace_function
  def resample_abs_pos_embed(
          posemb: torch.Tensor,
          new_size: List[int],
          old_size: Optional[List[int]] = None,
          num_prefix_tokens: int = 1,
          interpolation: str = 'bicubic',
          antialias: bool = True,
          verbose: bool = False,
  ):
      """Resample an absolute position embedding to a new 2D grid size.

      The leading ``num_prefix_tokens`` tokens (class token, etc) are split off and
      passed through untouched. The remaining grid tokens are reshaped to
      ``old_size``, interpolated to ``new_size`` in float32, cast back to the input
      dtype, and re-joined with the prefix tokens.

      Args:
          posemb: Position embedding indexed as ``(1, num_tokens, embed_dim)``.
          new_size: Target grid size as ``(height, width)``.
          old_size: Source grid size as ``(height, width)``. When ``None`` the
              source grid is assumed to be square and is inferred from the number
              of non-prefix tokens.
          num_prefix_tokens: Number of leading tokens that are not part of the grid.
          interpolation: Mode forwarded to ``F.interpolate``.
          antialias: Antialias flag forwarded to ``F.interpolate``.
          verbose: Log the resize at INFO level (skipped while scripting).

      Returns:
          The input tensor itself when the source and target grids are identical,
          otherwise a new tensor with ``new_size[0] * new_size[1]`` grid tokens
          (plus the prefix tokens) in the input dtype.

      Raises:
          RuntimeError: From ``reshape`` when the number of grid tokens does not
              match ``old_size`` (given or inferred).
      """
      num_pos_tokens = posemb.shape[1]
      num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens

      # sort out sizes, assume square if old size not provided
      if old_size is None:
          hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
          old_size = hw, hw

      # Only skip the resample when the source grid really equals the target grid.
      # Comparing token counts alone is not enough: a non-square source grid
      # (e.g. 2x8) has the same token count as a square target (4x4) and must
      # still be interpolated.
      if (
              num_new_tokens == num_pos_tokens
              and new_size[0] == old_size[0]
              and new_size[1] == old_size[1]
      ):
          return posemb

      if num_prefix_tokens:
          posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
      else:
          posemb_prefix, posemb = None, posemb

      # do the interpolation
      embed_dim = posemb.shape[-1]
      orig_dtype = posemb.dtype
      posemb = posemb.float()  # interpolate needs float32
      # Pin the channel dim to embed_dim instead of -1 so that a grid / token count
      # mismatch raises here rather than being silently folded into the channels.
      posemb = posemb.reshape(1, old_size[0], old_size[1], embed_dim).permute(0, 3, 1, 2)
      posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
      posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
      posemb = posemb.to(orig_dtype)

      # add back extra (class, etc) prefix tokens
      if posemb_prefix is not None:
          posemb = torch.cat([posemb_prefix, posemb], dim=1)

      if not torch.jit.is_scripting() and verbose:
          _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

      return posemb
  @torch.fx.wrap
  @register_notrace_function
  def resample_abs_pos_embed_nhwc(
          posemb: torch.Tensor,
          new_size: List[int],
          interpolation: str = 'bicubic',
          antialias: bool = True,
          verbose: bool = False,
  ):
      """Resample a channels-last (..., H, W, C) position embedding to a new grid size.

      Args:
          posemb: Position embedding whose last three dims are
              ``(height, width, embed_dim)``.
          new_size: Target grid size as ``(height, width)``.
          interpolation: Mode forwarded to ``F.interpolate``.
          antialias: Antialias flag forwarded to ``F.interpolate``.
          verbose: Log the resize at INFO level (skipped while scripting).

      Returns:
          The input tensor itself when its grid already matches ``new_size``,
          otherwise a ``(1, new_size[0], new_size[1], embed_dim)`` tensor in the
          input dtype.
      """
      if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
          return posemb

      # Capture the source grid now: `posemb` is rebound to the resampled tensor
      # below, so reading its shape later would report the new size as the old one.
      old_size = posemb.shape[-3:-1]

      orig_dtype = posemb.dtype
      posemb = posemb.float()  # interpolate in float32, cast back afterwards
      posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
      posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
      posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)

      if not torch.jit.is_scripting() and verbose:
          _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

      return posemb