pos_embed_sincos.py 47 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272
  1. """ Sin-cos, fourier, rotary position embedding modules and functions
  2. Hacked together by / Copyright 2022 Ross Wightman
  3. """
  4. import math
  5. from typing import List, Tuple, Optional, Union
  6. import torch
  7. from torch import nn as nn
  8. from ._fx import register_notrace_function
  9. from .grid import ndgrid
  10. from .trace_utils import _assert
  11. def pixel_freq_bands(
  12. num_bands: int,
  13. max_freq: float = 224.,
  14. linear_bands: bool = True,
  15. device: Optional[torch.device] = None,
  16. ):
  17. if linear_bands:
  18. bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
  19. else:
  20. bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
  21. return bands * torch.pi
  22. def freq_bands(
  23. num_bands: int,
  24. temperature: float = 10000.,
  25. step: int = 2,
  26. device: Optional[torch.device] = None,
  27. ) -> torch.Tensor:
  28. exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
  29. bands = 1. / (temperature ** exp)
  30. return bands
  31. def build_sincos2d_pos_embed(
  32. feat_shape: List[int],
  33. dim: int = 64,
  34. temperature: float = 10000.,
  35. reverse_coord: bool = False,
  36. interleave_sin_cos: bool = False,
  37. device: Optional[torch.device] = None,
  38. dtype: torch.dtype = torch.float32,
  39. ) -> torch.Tensor:
  40. """
  41. Args:
  42. feat_shape:
  43. dim:
  44. temperature:
  45. reverse_coord: stack grid order W, H instead of H, W
  46. interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
  47. dtype:
  48. device:
  49. Returns:
  50. """
  51. assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
  52. pos_dim = dim // 4
  53. bands = freq_bands(pos_dim, temperature=temperature, step=1, device=device)
  54. if reverse_coord:
  55. feat_shape = feat_shape[::-1] # stack W, H instead of H, W
  56. grid = torch.stack(ndgrid([
  57. torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
  58. for s in feat_shape
  59. ])).flatten(1).transpose(0, 1)
  60. pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
  61. # FIXME add support for unflattened spatial dim?
  62. stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
  63. pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
  64. return pos_emb.to(dtype=dtype)
  65. def swap_shape_xy(seq: List[int]) -> List[int]:
  66. if len(seq) < 2:
  67. return seq
  68. return [seq[1], seq[0]] + list(seq[2:])
  69. def build_fourier_pos_embed(
  70. feat_shape: List[int],
  71. bands: Optional[torch.Tensor] = None,
  72. num_bands: int = 64,
  73. max_res: int = 224,
  74. temperature: float = 10000.,
  75. linear_bands: bool = False,
  76. include_grid: bool = False,
  77. in_pixels: bool = True,
  78. ref_feat_shape: Optional[List[int]] = None,
  79. grid_offset: float = 0.,
  80. grid_indexing: str = 'ij',
  81. device: Optional[torch.device] = None,
  82. dtype: torch.dtype = torch.float32,
  83. ) -> List[torch.Tensor]:
  84. """
  85. Args:
  86. feat_shape: Feature shape for embedding.
  87. bands: Pre-calculated frequency bands.
  88. num_bands: Number of frequency bands (determines output dim).
  89. max_res: Maximum resolution for pixel based freq.
  90. temperature: Temperature for non-pixel freq.
  91. linear_bands: Linear band spacing for pixel based freq.
  92. include_grid: Include the spatial grid in output.
  93. in_pixels: Output in pixel freq.
  94. ref_feat_shape: Reference feature shape for resize / fine-tune.
  95. grid_offset: Constant offset to add to grid for non-pixel freq.
  96. grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
  97. dtype: Output dtype.
  98. device: Output device.
  99. Returns:
  100. """
  101. if bands is None:
  102. if in_pixels:
  103. bands = pixel_freq_bands(
  104. num_bands,
  105. float(max_res),
  106. linear_bands=linear_bands,
  107. device=device,
  108. )
  109. else:
  110. bands = freq_bands(
  111. num_bands,
  112. temperature=temperature,
  113. step=1,
  114. device=device,
  115. )
  116. else:
  117. if device is None:
  118. device = bands.device
  119. if dtype is None:
  120. dtype = bands.dtype
  121. if grid_indexing == 'xy':
  122. feat_shape = swap_shape_xy(feat_shape)
  123. if ref_feat_shape is not None:
  124. ref_feat_shape = swap_shape_xy(ref_feat_shape)
  125. if in_pixels:
  126. t = [
  127. torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32)
  128. for s in feat_shape
  129. ]
  130. else:
  131. t = [
  132. torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) + grid_offset
  133. for s in feat_shape
  134. ]
  135. if ref_feat_shape is not None:
  136. # eva's scheme for resizing rope embeddings (ref shape = pretrain)
  137. t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
  138. grid = torch.stack(torch.meshgrid(t, indexing=grid_indexing), dim=-1)
  139. grid = grid.unsqueeze(-1)
  140. pos = grid * bands
  141. pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype=dtype)
  142. out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
  143. return out
  144. class FourierEmbed(nn.Module):
  145. def __init__(
  146. self,
  147. max_res: int = 224,
  148. num_bands: int = 64,
  149. concat_grid=True,
  150. keep_spatial=False,
  151. device=None,
  152. dtype=None,
  153. ):
  154. super().__init__()
  155. self.max_res = max_res
  156. self.num_bands = num_bands
  157. self.concat_grid = concat_grid
  158. self.keep_spatial = keep_spatial
  159. self.register_buffer('bands', torch.empty(num_bands, device=device, dtype=dtype), persistent=False)
  160. # TODO: skip init when on meta device when safe to do so
  161. self.reset_parameters()
  162. def reset_parameters(self) -> None:
  163. """Initialize parameters and buffers."""
  164. self._init_buffers()
  165. def _init_buffers(self) -> None:
  166. """Compute and fill non-persistent buffer values."""
  167. self.bands.copy_(pixel_freq_bands(self.num_bands, self.max_res))
  168. def init_non_persistent_buffers(self) -> None:
  169. """Initialize non-persistent buffers."""
  170. self._init_buffers()
  171. def forward(self, x):
  172. B, C = x.shape[:2]
  173. feat_shape = x.shape[2:]
  174. emb = build_fourier_pos_embed(
  175. feat_shape,
  176. self.bands,
  177. include_grid=self.concat_grid,
  178. dtype=x.dtype,
  179. device=x.device,
  180. )
  181. emb = torch.cat(emb, dim=-1)
  182. emb = emb.transpose(-1, -2).flatten(len(feat_shape))
  183. batch_expand = (B,) + (-1,) * (x.ndim - 1)
  184. # FIXME support nD
  185. if self.keep_spatial:
  186. x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
  187. else:
  188. x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
  189. x = x.reshape(B, feat_shape.numel(), -1)
  190. return x
  191. def rot(x):
  192. # x: [ x0 x1 x2 x3 x4 x5]
  193. # out: [-x1 x0 -x3 x2 -x5 x4]
  194. return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
  195. def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
  196. # x: [ x0 x1 x2 x3 x4 x5]
  197. # out: [-x3 -x4 -x5 x0 x1 x2]
  198. x1, x2 = x.chunk(2, dim=-1)
  199. return torch.cat([-x2, x1], dim=-1)
  200. def apply_rot_embed(
  201. x: torch.Tensor,
  202. sin_emb: torch.Tensor,
  203. cos_emb: torch.Tensor,
  204. half: bool = False,
  205. ) -> torch.Tensor:
  206. # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
  207. if half:
  208. # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
  209. # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2
  210. # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2]
  211. return x * cos_emb + rope_rotate_half(x) * sin_emb
  212. else:
  213. # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
  214. # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
  215. # rot(x): eg [-x1, x0, -x3, x2, -x5, x4]
  216. return x * cos_emb + rot(x) * sin_emb
  217. def apply_rot_embed_list(
  218. x: List[torch.Tensor],
  219. sin_emb: torch.Tensor,
  220. cos_emb: torch.Tensor,
  221. half: bool = False
  222. ) -> List[torch.Tensor]:
  223. if isinstance(x, torch.Tensor):
  224. x = [x]
  225. # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
  226. if half:
  227. # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
  228. # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2
  229. # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2]
  230. return [t * cos_emb + rope_rotate_half(t) * sin_emb for t in x]
  231. else:
  232. # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
  233. # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
  234. # rot(x): eg [-x1, x0, -x3, x2, -x5, x4]
  235. return [t * cos_emb + rot(t) * sin_emb for t in x]
  236. def apply_rot_embed_cat(
  237. x: torch.Tensor,
  238. emb: torch.Tensor,
  239. half: bool = False
  240. ) -> torch.Tensor:
  241. sin_emb, cos_emb = emb.chunk(2, -1)
  242. # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
  243. if half:
  244. # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
  245. # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2
  246. # rope_rotate_half(x), eg [-x3, -x4, -x5, x0, x1, x2]
  247. return x * cos_emb + rope_rotate_half(x) * sin_emb
  248. else:
  249. # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
  250. # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
  251. # rot(x), eg [-x1, x0, -x3, x2, -x5, x4]
  252. return x * cos_emb + rot(x) * sin_emb
  253. def apply_keep_indices_nlc(
  254. x: torch.Tensor,
  255. pos_embed: torch.Tensor,
  256. keep_indices: torch.Tensor,
  257. pos_embed_has_batch: bool = False,
  258. ) -> torch.Tensor:
  259. """ Apply keep indices to different ROPE shapes
  260. Expected pos_embed shapes:
  261. * [seq_len, pos_embed_dim] --> output [batch_size, seq_len, pos_embed_dim]
  262. * [num_heads, seq_len, pos_embed_dim] --> output [batch_size, num_heads, seq_len, pos_embed_dim]
  263. * [depth, num_heads, seq_len, pos_embed_dim] --> output [batch_size, depth, num_heads, seq_len, pos_embed_dim]
  264. And all of the above with leading batch dimension already present if `pos_embed_has_batch == True`
  265. """
  266. if pos_embed_has_batch:
  267. # Pos embed already includes batch dim
  268. _assert(pos_embed.ndim >= 3, 'Incorrect number of dimensions') # At least [batch, seq_len, pos_embed_dim]
  269. else:
  270. # Add batch dimension and expand to batch size
  271. _assert(pos_embed.ndim >= 2, 'Incorrect number of dimensions') # At least [seq_len, pos_embed_dim]
  272. expand_shape = (x.shape[0],) + (-1,) * pos_embed.ndim
  273. pos_embed = pos_embed.unsqueeze(0).expand(expand_shape)
  274. # Reshape keep_indices to add singleton dims
  275. keep_shape = (keep_indices.shape[0],) + (1,) * (pos_embed.ndim - 3) + (keep_indices.shape[1], 1)
  276. keep_indices = keep_indices.view(keep_shape)
  277. # Expand all dims to match position embedding except the gather dim (second-last)
  278. keep_expand = list(pos_embed.shape)
  279. keep_expand[-2] = -1
  280. keep_indices = keep_indices.expand(keep_expand)
  281. return pos_embed.gather(-2, keep_indices)
  282. def build_rotary_pos_embed(
  283. feat_shape: List[int],
  284. bands: Optional[torch.Tensor] = None,
  285. dim: int = 64,
  286. max_res: int = 224,
  287. temperature: float = 10000.,
  288. linear_bands: bool = False,
  289. in_pixels: bool = True,
  290. ref_feat_shape: Optional[List[int]] = None,
  291. grid_offset: float = 0.,
  292. grid_indexing: str = 'ij',
  293. device: Optional[torch.device] = None,
  294. dtype: torch.dtype = torch.float32,
  295. ):
  296. """
  297. Args:
  298. feat_shape: Spatial shape of the target tensor for embedding.
  299. bands: Optional pre-generated frequency bands
  300. dim: Output dimension of embedding tensor.
  301. max_res: Maximum resolution for pixel mode.
  302. temperature: Temperature (inv freq) for non-pixel mode
  303. linear_bands: Linearly (instead of log) spaced bands for pixel mode
  304. in_pixels: Pixel vs language (inv freq) mode.
  305. ref_feat_shape: Reference feature shape for resize / fine-tune.
  306. grid_offset: Constant offset to add to grid for non-pixel freq.
  307. grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
  308. device: Output device.
  309. dtype: Output dtype.
  310. Returns:
  311. """
  312. sin_emb, cos_emb = build_fourier_pos_embed(
  313. feat_shape,
  314. bands=bands,
  315. num_bands=dim // 4,
  316. max_res=max_res,
  317. temperature=temperature,
  318. linear_bands=linear_bands,
  319. in_pixels=in_pixels,
  320. ref_feat_shape=ref_feat_shape,
  321. grid_offset=grid_offset,
  322. grid_indexing=grid_indexing,
  323. device=device,
  324. dtype=dtype,
  325. )
  326. num_spatial_dim = 1
  327. # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
  328. for x in feat_shape:
  329. num_spatial_dim *= x
  330. sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
  331. cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
  332. return sin_emb, cos_emb
  333. class RotaryEmbedding(nn.Module):
  334. """ Rotary position embedding
  335. NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
  336. been well tested, and will likely change. It will be moved to its own file.
  337. The following impl/resources were referenced for this impl:
  338. * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
  339. * https://blog.eleuther.ai/rotary-embeddings/
  340. """
  341. def __init__(
  342. self,
  343. dim,
  344. max_res=224,
  345. temperature=10000,
  346. in_pixels=True,
  347. linear_bands: bool = False,
  348. feat_shape: Optional[List[int]] = None,
  349. ref_feat_shape: Optional[List[int]] = None,
  350. grid_offset: float = 0.,
  351. grid_indexing: str = 'ij',
  352. device=None,
  353. dtype=None,
  354. ):
  355. super().__init__()
  356. self.dim = dim
  357. self.max_res = max_res
  358. self.temperature = temperature
  359. self.linear_bands = linear_bands
  360. self.in_pixels = in_pixels
  361. self.feat_shape = feat_shape
  362. self.ref_feat_shape = ref_feat_shape
  363. self.grid_offset = grid_offset
  364. self.grid_indexing = grid_indexing
  365. # Track which mode we're in
  366. self._use_cached_embed = feat_shape is not None
  367. if feat_shape is None:
  368. # bands mode: cache bands, rebuild embeddings on each get_embed call
  369. bands_shape = (dim // 4,)
  370. self.register_buffer('bands', torch.empty(bands_shape, device=device, dtype=dtype), persistent=False)
  371. self.pos_embed_sin = None
  372. self.pos_embed_cos = None
  373. else:
  374. # embed mode: cache full sin/cos embeddings
  375. self.bands = None
  376. num_pos = 1
  377. for s in feat_shape:
  378. num_pos *= s
  379. emb_shape = (num_pos, dim)
  380. self.register_buffer('pos_embed_sin', torch.empty(emb_shape, device=device, dtype=dtype), persistent=False)
  381. self.register_buffer('pos_embed_cos', torch.empty(emb_shape, device=device, dtype=dtype), persistent=False)
  382. # TODO: skip init when on meta device when safe to do so
  383. self.reset_parameters()
  384. def reset_parameters(self) -> None:
  385. """Initialize parameters and buffers."""
  386. self._init_buffers()
  387. def _init_buffers(self) -> None:
  388. """Compute and fill non-persistent buffer values."""
  389. if not self._use_cached_embed:
  390. self.bands.copy_(self._compute_bands())
  391. else:
  392. emb_sin, emb_cos = self._get_pos_embed_values(self.feat_shape)
  393. self.pos_embed_sin.copy_(emb_sin)
  394. self.pos_embed_cos.copy_(emb_cos)
  395. def _compute_bands(self, device=None, dtype=None):
  396. """Compute frequency bands."""
  397. if self.in_pixels:
  398. bands = pixel_freq_bands(
  399. self.dim // 4,
  400. float(self.max_res),
  401. linear_bands=self.linear_bands,
  402. )
  403. else:
  404. bands = freq_bands(
  405. self.dim // 4,
  406. temperature=self.temperature,
  407. step=1,
  408. )
  409. return bands.to(device=device, dtype=dtype)
  410. def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
  411. emb_sin, emb_cos = build_rotary_pos_embed(
  412. feat_shape=feat_shape,
  413. dim=self.dim,
  414. max_res=self.max_res,
  415. temperature=self.temperature,
  416. linear_bands=self.linear_bands,
  417. in_pixels=self.in_pixels,
  418. ref_feat_shape=self.ref_feat_shape,
  419. grid_offset=self.grid_offset,
  420. grid_indexing=self.grid_indexing,
  421. device=device,
  422. dtype=dtype,
  423. )
  424. return emb_sin, emb_cos
  425. def init_non_persistent_buffers(self) -> None:
  426. """Initialize non-persistent buffers."""
  427. self._init_buffers()
  428. def update_feat_shape(self, feat_shape: List[int]):
  429. if self.feat_shape is not None and feat_shape != self.feat_shape:
  430. # only update if feat_shape was set and different from previous value
  431. assert self.pos_embed_sin is not None
  432. assert self.pos_embed_cos is not None
  433. self.pos_embed_sin, self.pos_embed_cos = self._get_pos_embed_values(
  434. feat_shape,
  435. device=self.pos_embed_sin.device,
  436. dtype=self.pos_embed_sin.dtype,
  437. )
  438. self.feat_shape = feat_shape
  439. def get_embed(self, shape: Optional[List[int]] = None):
  440. if shape is not None and self.bands is not None:
  441. # rebuild embeddings every call, use if target shape changes
  442. return build_rotary_pos_embed(
  443. shape,
  444. self.bands,
  445. in_pixels=self.in_pixels,
  446. ref_feat_shape=self.ref_feat_shape,
  447. grid_offset=self.grid_offset,
  448. grid_indexing=self.grid_indexing,
  449. )
  450. elif self.pos_embed_sin is not None and self.pos_embed_cos is not None:
  451. return self.pos_embed_sin, self.pos_embed_cos
  452. else:
  453. assert False, "get_embed() requires pre-computed pos embeds or valid shape w/ pre-computed bands"
  454. def forward(self, x):
  455. # assuming channel-first tensor where spatial dim are >= 2
  456. sin_emb, cos_emb = self.get_embed(x.shape[2:])
  457. return apply_rot_embed(x, sin_emb, cos_emb)
  458. class RotaryEmbeddingCat(nn.Module):
  459. """ Rotary position embedding w/ concatenatd sin & cos
  460. The following impl/resources were referenced for this impl:
  461. * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
  462. * https://blog.eleuther.ai/rotary-embeddings/
  463. """
  464. def __init__(
  465. self,
  466. dim: int,
  467. max_res: int = 224,
  468. temperature: float = 10000,
  469. in_pixels: bool = True,
  470. linear_bands: bool = False,
  471. feat_shape: Optional[List[int]] = None,
  472. ref_feat_shape: Optional[List[int]] = None,
  473. grid_offset: float = 0.,
  474. grid_indexing: str = 'ij',
  475. device=None,
  476. dtype=None,
  477. ):
  478. super().__init__()
  479. self.dim = dim
  480. self.max_res = max_res
  481. self.temperature = temperature
  482. self.in_pixels = in_pixels
  483. self.linear_bands = linear_bands
  484. self.feat_shape = feat_shape
  485. self.ref_feat_shape = ref_feat_shape
  486. self.grid_offset = grid_offset
  487. self.grid_indexing = grid_indexing
  488. # Track which mode we're in
  489. self._use_cached_embed = feat_shape is not None
  490. if feat_shape is None:
  491. # bands mode: cache bands, rebuild embeddings on each get_embed call
  492. bands_shape = (dim // 4,)
  493. self.register_buffer('bands', torch.empty(bands_shape, device=device, dtype=dtype), persistent=False)
  494. self.pos_embed = None
  495. else:
  496. # embed mode: cache full embeddings
  497. self.bands = None
  498. num_pos = 1
  499. for s in feat_shape:
  500. num_pos *= s
  501. emb_shape = (num_pos, dim * 2) # concatenated sin & cos
  502. self.register_buffer('pos_embed', torch.empty(emb_shape, device=device, dtype=dtype), persistent=False)
  503. # TODO: skip init when on meta device when safe to do so
  504. self.reset_parameters()
  505. def reset_parameters(self) -> None:
  506. """Initialize parameters and buffers."""
  507. self._init_buffers()
  508. def _init_buffers(self) -> None:
  509. """Compute and fill non-persistent buffer values."""
  510. if not self._use_cached_embed:
  511. self.bands.copy_(self._compute_bands())
  512. else:
  513. self.pos_embed.copy_(self._get_pos_embed_values(self.feat_shape))
  514. def _compute_bands(self, device=None, dtype=None):
  515. """Compute frequency bands."""
  516. if self.in_pixels:
  517. bands = pixel_freq_bands(
  518. self.dim // 4,
  519. float(self.max_res),
  520. linear_bands=self.linear_bands,
  521. )
  522. else:
  523. bands = freq_bands(
  524. self.dim // 4,
  525. temperature=self.temperature,
  526. step=1,
  527. )
  528. return bands.to(device=device, dtype=dtype)
  529. def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
  530. embeds = build_rotary_pos_embed(
  531. feat_shape=feat_shape,
  532. dim=self.dim,
  533. max_res=self.max_res,
  534. temperature=self.temperature,
  535. linear_bands=self.linear_bands,
  536. in_pixels=self.in_pixels,
  537. ref_feat_shape=self.ref_feat_shape,
  538. grid_offset=self.grid_offset,
  539. grid_indexing=self.grid_indexing,
  540. device=device,
  541. dtype=dtype,
  542. )
  543. return torch.cat(embeds, -1)
  544. def init_non_persistent_buffers(self) -> None:
  545. """Initialize non-persistent buffers."""
  546. self._init_buffers()
  547. def update_feat_shape(self, feat_shape: List[int]):
  548. if self.feat_shape is not None and feat_shape != self.feat_shape:
  549. # only update if feat_shape was set and different from previous value
  550. assert self.pos_embed is not None
  551. self.pos_embed = self._get_pos_embed_values(
  552. feat_shape,
  553. device=self.pos_embed.device,
  554. dtype=self.pos_embed.dtype,
  555. )
  556. self.feat_shape = feat_shape
  557. def get_embed(self, shape: Optional[List[int]] = None):
  558. if shape is not None and self.bands is not None:
  559. # rebuild embeddings from cached bands every call, use if target shape changes
  560. embeds = build_rotary_pos_embed(
  561. shape,
  562. self.bands,
  563. in_pixels=self.in_pixels,
  564. ref_feat_shape=self.ref_feat_shape,
  565. grid_offset=self.grid_offset,
  566. grid_indexing=self.grid_indexing,
  567. )
  568. return torch.cat(embeds, -1)
  569. elif self.pos_embed is not None:
  570. return self.pos_embed
  571. else:
  572. assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands"
  573. def get_batch_embeds(
  574. self,
  575. shapes: List[Tuple[int, int]],
  576. seq_len: Optional[int] = None,
  577. ) -> Union[torch.Tensor, List[torch.Tensor]]:
  578. """Generate ROPE embeddings for multiple grid shapes efficiently.
  579. Computes embeddings for the maximum grid size once, then extracts
  580. and flattens the relevant portions for each requested shape.
  581. Args:
  582. shapes: List of (H, W) tuples representing different grid sizes
  583. Returns:
  584. List of concatenated sin/cos embeddings for each shape,
  585. where each tensor has shape (H*W, dim)
  586. """
  587. if not shapes:
  588. return []
  589. # Check if we have pre-computed bands
  590. if self.bands is None:
  591. # If we have pre-computed pos_embed for a fixed shape, we can't do batch generation
  592. raise RuntimeError("Batch embedding generation requires cached bands, not pre-computed embeddings")
  593. # Find max dimensions across all shapes
  594. max_h = max(h for h, w in shapes)
  595. max_w = max(w for h, w in shapes)
  596. # Generate embeddings for max size ONCE
  597. sin_emb, cos_emb = build_rotary_pos_embed(
  598. feat_shape=(max_h, max_w),
  599. bands=self.bands,
  600. in_pixels=self.in_pixels,
  601. ref_feat_shape=self.ref_feat_shape,
  602. grid_offset=self.grid_offset,
  603. grid_indexing=self.grid_indexing,
  604. )
  605. # sin_emb and cos_emb are (max_h * max_w, dim//2)
  606. # concat and reshape to 2D for slicing
  607. rope_embed_2d = torch.cat([sin_emb, cos_emb], dim=-1).view(max_h, max_w, -1)
  608. if seq_len is not None:
  609. flat_embeds = torch.zeros(len(shapes), seq_len, rope_embed_2d.shape[-1]).type_as(sin_emb)
  610. for i, (h, w) in enumerate(shapes):
  611. src_len = h * w
  612. flat_embeds[i, :src_len] = rope_embed_2d[:h, :w].reshape(src_len, -1)
  613. return flat_embeds
  614. else:
  615. flat_embeds_list = [rope_embed_2d[:h, :w].reshape(h * w, -1) for h, w in shapes]
  616. return flat_embeds_list
  617. def forward(self, x):
  618. # assuming channel-first tensor where spatial dim are >= 2
  619. pos_embed = self.get_embed(x.shape[2:])
  620. return apply_rot_embed_cat(x, pos_embed)
  621. def init_random_2d_freqs(
  622. head_dim: int,
  623. depth: int,
  624. num_heads: int,
  625. temperature: float = 10.0,
  626. rotate: bool = True,
  627. *,
  628. device=None,
  629. dtype=torch.float32,
  630. ) -> torch.Tensor:
  631. """ Vectorised 2D ROPE frequencies with random rotation for mixed mode ROPE.
  632. Returns:
  633. Tensor (2, depth, num_heads, head_dim//2)
  634. """
  635. # base magnitudes, shape: (head_dim//4,)
  636. mag = 1.0 / (temperature ** (torch.arange(0, head_dim, 4, device=device, dtype=dtype) / head_dim))
  637. # (1,1,L) so it broadcasts over both depth and heads
  638. mag = mag.unsqueeze(0).unsqueeze(0) # (1,1,L)
  639. # random (or zero) rotation per head *and* per block
  640. if rotate:
  641. angles = torch.rand(depth, num_heads, 1, device=device, dtype=dtype) * 2 * torch.pi
  642. else:
  643. angles = torch.zeros(depth, num_heads, 1, device=device, dtype=dtype)
  644. # build (depth, num_heads, 2·L) == head_dim//2 on the last axis
  645. fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(angles + torch.pi / 2)], dim=-1)
  646. fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(angles + torch.pi / 2)], dim=-1)
  647. # (2, depth, num_heads, head_dim//2)
  648. return torch.stack([fx, fy], dim=0)
  649. @torch.fx.wrap
  650. @register_notrace_function
  651. def get_mixed_grid(
  652. shape: List[int],
  653. grid_indexing: str = 'ij',
  654. device: Optional[torch.device] = None,
  655. dtype: torch.dtype = torch.float32,
  656. ) -> Tuple[torch.Tensor, torch.Tensor]:
  657. if grid_indexing == 'xy':
  658. shape = swap_shape_xy(shape)
  659. x_pos, y_pos = torch.meshgrid(
  660. torch.arange(shape[0], device=device, dtype=torch.float32),
  661. torch.arange(shape[1], device=device, dtype=torch.float32),
  662. indexing=grid_indexing,
  663. )
  664. t_x = x_pos.to(dtype).flatten()
  665. t_y = y_pos.to(dtype).flatten()
  666. return t_x, t_y
  667. def get_mixed_freqs(
  668. freqs: torch.Tensor,
  669. t_x: torch.Tensor,
  670. t_y: torch.Tensor,
  671. ) -> torch.Tensor:
  672. """Compute mixed (learnable) frequencies."""
  673. # Create position indices
  674. dtype = freqs.dtype
  675. freqs = freqs.float()
  676. freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2))
  677. freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2))
  678. combined = freqs_x + freqs_y # shape: (num_heads, N, dim//4)
  679. sin_emb = torch.sin(combined).repeat_interleave(2, -1) # (N, dim//2)
  680. cos_emb = torch.cos(combined).repeat_interleave(2, -1) # (N, dim//2)
  681. rope_embeds = torch.cat([sin_emb, cos_emb], dim=-1) # (num_heads, H*W, head_dim)
  682. return rope_embeds.to(dtype)
  683. class RotaryEmbeddingMixed(nn.Module):
  684. """Rotary position embedding with depth-dependent learnable frequencies.
  685. This implementation supports mixed (learnable) ROPE. In mixed mode,
  686. each transformer block has its own set of learnable frequency parameters.
  687. Based on 'Rotary Position Embedding for Vision: https://arxiv.org/abs/2403.13298)'
  688. Compatible with original at https://github.com/naver-ai/rope-vit
  689. """
  690. def __init__(
  691. self,
  692. dim: int,
  693. depth: int,
  694. num_heads: int,
  695. temperature: float = 10.0,
  696. feat_shape: Optional[List[int]] = None,
  697. grid_indexing: str = 'xy',
  698. device=None,
  699. dtype=None,
  700. ):
  701. """Initialize rotary embeddings.
  702. Args:
  703. dim: Embedding dimension (should be divisible by 4)
  704. depth: Number of transformer blocks
  705. num_heads: Number of attention heads
  706. temperature: Base for frequency computation
  707. feat_shape: Spatial dimensions [H, W] if known in advance
  708. grid_indexing: How to index grid positions ('xy' or 'ij')
  709. """
  710. super().__init__()
  711. self.dim = dim
  712. self.depth = depth
  713. self.num_heads = num_heads
  714. self.temperature = temperature
  715. self.feat_shape = feat_shape
  716. self.grid_indexing = grid_indexing
  717. head_dim = dim // num_heads
  718. assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"
  719. freqs = init_random_2d_freqs(
  720. head_dim,
  721. depth,
  722. num_heads,
  723. temperature=temperature,
  724. rotate=True,
  725. device=device,
  726. dtype=dtype,
  727. ) # (2, depth, num_heads, head_dim//2)
  728. self.freqs = nn.Parameter(freqs)
  729. if feat_shape is not None:
  730. # cache pre-computed grid
  731. num_pos = 1
  732. for s in feat_shape:
  733. num_pos *= s
  734. self.register_buffer('t_x', torch.empty(num_pos, device=device, dtype=dtype), persistent=False)
  735. self.register_buffer('t_y', torch.empty(num_pos, device=device, dtype=dtype), persistent=False)
  736. # TODO: skip init when on meta device when safe to do so
  737. self._init_buffers()
  738. else:
  739. self.t_x = self.t_y = None
  740. def _init_buffers(self) -> None:
  741. """Compute and fill non-persistent buffer values."""
  742. if self.feat_shape is not None:
  743. t_x, t_y = self._get_grid_values(self.feat_shape)
  744. self.t_x.copy_(t_x)
  745. self.t_y.copy_(t_y)
  746. def reset_parameters(self) -> None:
  747. """Initialize parameters and buffers."""
  748. self._init_buffers()
  749. def _get_grid_values(self, feat_shape: Optional[List[int]]):
  750. t_x, t_y = get_mixed_grid(
  751. feat_shape,
  752. grid_indexing=self.grid_indexing,
  753. device=self.freqs.device,
  754. )
  755. return t_x, t_y
  756. def update_feat_shape(self, feat_shape: Optional[List[int]]):
  757. if self.feat_shape is not None and feat_shape != self.feat_shape:
  758. assert self.t_x is not None
  759. assert self.t_y is not None
  760. t_x, t_y = self._get_grid_values(feat_shape)
  761. self.t_x = t_x.to(self.t_x.device, self.t_x.dtype)
  762. self.t_y = t_y.to(self.t_y.device, self.t_y.dtype)
  763. self.feat_shape = feat_shape
  764. def init_non_persistent_buffers(self) -> None:
  765. """Initialize non-persistent buffers."""
  766. self._init_buffers()
  767. def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
  768. """Generate rotary embeddings for the given spatial shape.
  769. Args:
  770. shape: Spatial dimensions [H, W]
  771. Returns:
  772. Tensor of shape (depth, H*W, dim) containing concatenated sin/cos embeddings
  773. """
  774. if shape is not None:
  775. t_x, t_y = get_mixed_grid(
  776. shape,
  777. grid_indexing=self.grid_indexing,
  778. device=self.freqs.device
  779. )
  780. elif self.t_x is not None and self.t_y is not None:
  781. t_x, t_y = self.t_x, self.t_y
  782. else:
  783. assert False, "get_embed() requires pre-computed t_x/t_y or valid shape"
  784. return get_mixed_freqs(self.freqs, t_x, t_y)
  785. def get_batch_embeds(
  786. self,
  787. shapes: List[Tuple[int, int]],
  788. seq_len: Optional[int] = None,
  789. ) -> Union[torch.Tensor, List[torch.Tensor]]:
  790. """Generate ROPE embeddings for multiple grid shapes efficiently.
  791. Computes embeddings for the maximum grid size once, then extracts
  792. and flattens the relevant portions for each requested shape.
  793. Args:
  794. shapes: List of (H, W) tuples representing different grid sizes
  795. seq_len: If provided, return padded tensor of this length. Otherwise return list.
  796. Returns:
  797. If seq_len is provided: Padded tensor of shape (len(shapes), depth, num_heads, seq_len, dim)
  798. Otherwise: List of tensors with shape (depth, num_heads, H*W, dim) for each shape
  799. """
  800. if not shapes:
  801. return []
  802. # Find max dimensions
  803. max_h = max(h for h, w in shapes)
  804. max_w = max(w for h, w in shapes)
  805. # Generate embeddings for max size ONCE
  806. t_x, t_y = get_mixed_grid(
  807. [max_h, max_w],
  808. grid_indexing=self.grid_indexing,
  809. device=self.freqs.device
  810. )
  811. max_embed = get_mixed_freqs(self.freqs, t_x, t_y) # (depth, num_heads, max_h*max_w, dim)
  812. # Reshape to 2D grid for easy slicing
  813. depth, num_heads, _, dim = max_embed.shape
  814. max_embed_2d = max_embed.view(depth, num_heads, max_h, max_w, dim)
  815. if seq_len is not None:
  816. # Return padded tensor
  817. B = len(shapes)
  818. padded = torch.zeros(B, depth, num_heads, seq_len, dim, device=self.freqs.device, dtype=self.freqs.dtype)
  819. for i, (h, w) in enumerate(shapes):
  820. # Slice and flatten
  821. embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
  822. actual_len = h * w
  823. padded[i, :, :, :actual_len] = embed_slice
  824. return padded
  825. else:
  826. # Return list
  827. results = []
  828. for h, w in shapes:
  829. # Slice and flatten
  830. embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
  831. results.append(embed_slice)
  832. return results
  833. def forward(self, x):
  834. # assuming channel-first tensor where spatial dim are >= 2
  835. pos_embed = self.get_embed(x.shape[2:])
  836. return apply_rot_embed_cat(x, pos_embed)
  837. def no_weight_decay(self):
  838. """Exclude frequency parameters from weight decay."""
  839. return {'freqs'}
  840. @torch.fx.wrap
  841. @register_notrace_function
  842. def make_coords_dinov3(
  843. height: int,
  844. width: int,
  845. normalize_coords: str = 'separate',
  846. grid_indexing: str = 'ij',
  847. grid_offset: float = 0.,
  848. device: torch.device = 'cpu',
  849. dtype: torch.dtype = torch.float32,
  850. ) -> torch.Tensor:
  851. """Make coordinate grid matching offset and normalization of original.
  852. Returns: coords with shape (HW, 2) in [-1, 1].
  853. """
  854. # 0.5-centered indices with optional offset
  855. coords_h = torch.arange(0.5, height, device=device, dtype=torch.float32) + grid_offset
  856. coords_w = torch.arange(0.5, width, device=device, dtype=torch.float32) + grid_offset
  857. # Normalization denominators
  858. if normalize_coords == "max":
  859. denom = float(max(height, width))
  860. h_denom = denom
  861. w_denom = denom
  862. elif normalize_coords == "min":
  863. denom = float(min(height, width))
  864. h_denom = denom
  865. w_denom = denom
  866. elif normalize_coords == "separate":
  867. h_denom = float(height)
  868. w_denom = float(width)
  869. else:
  870. raise ValueError(f"Unknown normalize_coords: {normalize_coords}")
  871. # Normalize to [0, 1]
  872. coords_h = coords_h / h_denom
  873. coords_w = coords_w / w_denom
  874. coords_h = coords_h.to(dtype)
  875. coords_w = coords_w.to(dtype)
  876. # Create grid then map to [-1, 1]
  877. if grid_indexing == "xy":
  878. grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy")
  879. coords = torch.stack([grid_h, grid_w], dim=-1) # (H, W, 2) -> (h, w order)
  880. else:
  881. coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # (H, W, 2)
  882. coords = coords.flatten(0, 1) # (HW, 2)
  883. coords = 2.0 * coords - 1.0 # (H, W, 2) in [-1, 1]
  884. return coords
  885. class RotaryEmbeddingDinoV3(nn.Module):
  886. """RoPE for timm DinoV3 port, numerically matching original.
  887. Math is aligned to original DinoV3 RopePositionEmbedding at https://github.com/facebookresearch/dinov3:
  888. - 0.5-centered coords normalized by H/W (or min/max), mapped to [-1,1]
  889. - training-time augmentations (shift/jitter/rescale)
  890. - periods schedule equals Rope's temperature (base) or min/max period
  891. """
  892. def __init__(
  893. self,
  894. dim: int,
  895. temperature: Optional[float] = 100.0,
  896. min_period: Optional[float] = None,
  897. max_period: Optional[float] = None,
  898. feat_shape: Optional[List[int]] = None,
  899. normalize_coords: str = "separate", # 'min', 'max', 'separate'
  900. grid_offset: float = 0.0,
  901. grid_indexing: str = "ij",
  902. rotate_half: bool = True,
  903. shift_coords: Optional[float] = None,
  904. jitter_coords: Optional[float] = None, # interpreted as factor J >= 1
  905. rescale_coords: Optional[float] = None, # interpreted as factor R >= 1
  906. device=None,
  907. dtype=None,
  908. ):
  909. super().__init__()
  910. # Dimensions / output format
  911. self.dim = dim # equal to head_dim for most vit applications
  912. self.rotate_half = rotate_half
  913. # Period schedule parameters
  914. self.temperature = float(temperature)
  915. self.min_period = min_period
  916. self.max_period = max_period
  917. # Coord processing + augs
  918. self.normalize_coords = normalize_coords
  919. self.shift_coords = shift_coords
  920. self.jitter_coords = jitter_coords
  921. self.rescale_coords = rescale_coords
  922. self.aug_active = any([a is not None for a in [self.shift_coords, self.jitter_coords, self.rescale_coords]])
  923. # Grid config
  924. self.feat_shape = feat_shape
  925. self.grid_offset = grid_offset
  926. self.grid_indexing = grid_indexing
  927. # Register empty buffer for periods
  928. periods_shape = (dim // 4,)
  929. self.register_buffer("periods", torch.empty(periods_shape, device=device, dtype=dtype), persistent=False)
  930. if feat_shape is not None:
  931. # Register empty buffer for cached embeddings
  932. num_pos = feat_shape[0] * feat_shape[1]
  933. emb_shape = (num_pos, dim * 2) # concatenated sin & cos
  934. self.register_buffer("pos_embed_cached", torch.empty(emb_shape, device=device, dtype=dtype), persistent=False)
  935. else:
  936. self.pos_embed_cached = None
  937. # TODO: skip init when on meta device when safe to do so
  938. self.reset_parameters()
  939. def reset_parameters(self) -> None:
  940. """Initialize parameters and buffers."""
  941. self._init_buffers()
  942. def _init_buffers(self) -> None:
  943. """Compute and fill non-persistent buffer values."""
  944. self.periods.copy_(self._compute_periods())
  945. if self.feat_shape is not None and self.pos_embed_cached is not None:
  946. rope_embed = self._create_embed(self.feat_shape, no_aug=True)
  947. self.pos_embed_cached.copy_(rope_embed)
  948. def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = torch.float32) -> torch.Tensor:
  949. """Construct periods from either min/max or temperature."""
  950. dim = self.dim // 4
  951. if self.min_period is not None and self.max_period is not None:
  952. exponents = torch.linspace(0, 1, dim, device='cpu', dtype=torch.float32)
  953. periods = self.min_period * ((self.max_period / self.min_period) ** exponents)
  954. else:
  955. if self.temperature is None:
  956. raise ValueError("Provide either min/max periods or `temperature`.")
  957. exponents = 2.0 * torch.arange(dim, device='cpu', dtype=torch.float32) / (self.dim // 2)
  958. periods = self.temperature ** exponents
  959. # NOTE: The original dinv3 model weights have periods downcast to bfloat16 in persistent buffers,
  960. # loaded models will differ a bit vs timm as periods is not persistent and generated in float32 by default
  961. return periods.to(device=device, dtype=dtype)
  962. def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
  963. """Apply shift/jitter/rescale train time augmentations."""
  964. if not self.training or not self.aug_active:
  965. return coords
  966. device = coords.device
  967. dtype = coords.dtype
  968. # Shift per-axis in [-s, +s]
  969. if self.shift_coords is not None:
  970. shift = float(self.shift_coords)
  971. shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, shift)
  972. coords = coords + shift_hw[None, :]
  973. # Jitter: per-axis log-uniform factor in [1/J, J]
  974. if self.jitter_coords is not None:
  975. jitter_factor = float(self.jitter_coords)
  976. if jitter_factor <= 0:
  977. raise ValueError("jitter_coords must be > 0 (interpreted as multiplicative factor).")
  978. jitter_max = math.log(jitter_factor)
  979. jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, jitter_max).exp()
  980. coords = coords * jitter_hw[None, :]
  981. # Rescale: shared scalar log-uniform factor in [1/R, R]
  982. if self.rescale_coords is not None:
  983. rescale_factor = float(self.rescale_coords)
  984. if rescale_factor <= 0:
  985. raise ValueError("rescale_coords must be > 0 (interpreted as multiplicative factor).")
  986. rescale_max = math.log(rescale_factor)
  987. rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, rescale_max).exp()
  988. coords = coords * rescale
  989. return coords
  990. def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  991. """Return sin/cos embeddings with either 'half' or 'interleaved' layout."""
  992. # coords: (HW, 2); periods: (dim)
  993. dim = self.dim // 4
  994. device = self.periods.device
  995. dtype = self.periods.dtype
  996. assert self.periods.numel() == dim
  997. # NOTE this is a slightly later device/dtype switch than original
  998. coords = coords[:, :, None].to(device=device, dtype=dtype)
  999. angles = 2 * math.pi * coords / self.periods[None, None, :]
  1000. angles = angles.flatten(1) # (HW, dim // 2)
  1001. if self.rotate_half:
  1002. # Tile (half layout) (HW, dim // 2) -> (HW, dim)
  1003. angles = angles.tile(2)
  1004. else:
  1005. # Interleaved layout (HW, dim // 2) -> (HW, dim)
  1006. angles = angles.repeat_interleave(2, dim=-1)
  1007. sin = torch.sin(angles)
  1008. cos = torch.cos(angles)
  1009. return sin, cos
  1010. def _create_embed(
  1011. self,
  1012. feat_shape: List[int],
  1013. no_aug: bool = False,
  1014. ) -> torch.Tensor:
  1015. H, W = feat_shape
  1016. coords = make_coords_dinov3(
  1017. H, W,
  1018. normalize_coords=self.normalize_coords,
  1019. grid_indexing=self.grid_indexing,
  1020. grid_offset=self.grid_offset,
  1021. ) # (HW, 2)
  1022. if not no_aug:
  1023. coords = self._apply_coord_augs(coords)
  1024. sin, cos = self._get_pos_embed_from_coords(coords) # 2 * (HW, dim)
  1025. rope_embed = torch.cat([sin, cos], dim=-1) # (HW, 2*dim)
  1026. return rope_embed
  1027. def _cache_embed(self, feat_shape: List[int]):
  1028. # create non-augmented embeds for cache
  1029. rope_embed = self._create_embed(feat_shape, no_aug=True)
  1030. self.register_buffer("pos_embed_cached", rope_embed, persistent=False)
  1031. self.feat_shape = feat_shape
  1032. def update_feat_shape(self, feat_shape: List[int]):
  1033. if self.feat_shape is not None and feat_shape != self.feat_shape:
  1034. # only update if feat_shape was set (valid cache) and different from previous value
  1035. self._cache_embed(feat_shape)
  1036. def init_non_persistent_buffers(self) -> None:
  1037. """Initialize non-persistent buffers."""
  1038. self._init_buffers()
  1039. def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
  1040. """Generate rope_embed matching DINOv3 RopePositionEmbedding numerics.
  1041. Returns: (HW, num_heads, 2 * head_dim) with last dim = [sin, cos] cat.
  1042. """
  1043. if shape is not None:
  1044. rope_embed = self._create_embed(shape)
  1045. else:
  1046. need_create = self.pos_embed_cached is None or (self.training and self.aug_active)
  1047. if need_create:
  1048. assert self.feat_shape is not None, 'feature shape must be cached on create'
  1049. rope_embed = self._create_embed(self.feat_shape)
  1050. else:
  1051. assert self.pos_embed_cached is not None
  1052. rope_embed = self.pos_embed_cached
  1053. return rope_embed
  1054. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1055. """Get and apply rotary embeddings to x"""
  1056. # assuming channel-first tensor where spatial dim are >= 2
  1057. pos_embed = self.get_embed(x.shape[2:])
  1058. return apply_rot_embed_cat(x, pos_embed, half=self.rotate_half)
  1059. def create_rope_embed(
  1060. rope_type: str = 'cat',
  1061. dim: int = 768,
  1062. num_heads: int = 12,
  1063. **kwargs
  1064. ) -> nn.Module:
  1065. """Factory function for creating rotary position embeddings.
  1066. Args:
  1067. rope_type: Type of RoPE to create. Options:
  1068. - 'base': Basic RotaryEmbedding
  1069. - 'cat': RotaryEmbeddingCat (concatenated sin/cos)
  1070. - 'mixed': RotaryEmbeddingMixed (learnable per-depth frequencies)
  1071. - 'dinov3': RotaryEmbeddingDinoV3 (with coordinate transforms)
  1072. dim: Total embedding dimension
  1073. num_heads: Number of attention heads
  1074. **kwargs: Additional arguments passed to the specific RoPE class
  1075. Returns:
  1076. Rotary embedding module
  1077. """
  1078. if rope_type == 'base':
  1079. kwargs.pop('rotate_half', None) # doesn't support
  1080. return RotaryEmbedding(dim=dim // num_heads, **kwargs)
  1081. elif rope_type == 'cat':
  1082. kwargs.pop('rotate_half', None) # doesn't support
  1083. return RotaryEmbeddingCat(dim=dim // num_heads, **kwargs)
  1084. elif rope_type == 'mixed':
  1085. # Mixed requires depth parameter, generates differing embeddings per layer and head
  1086. kwargs.pop('in_pixels', None) # doesn't support
  1087. kwargs.pop('ref_feat_shape', None) # doesn't support
  1088. return RotaryEmbeddingMixed(dim=dim, num_heads=num_heads, **kwargs)
  1089. elif rope_type == 'dinov3':
  1090. kwargs.pop('in_pixels', None) # doesn't support
  1091. kwargs.pop('ref_feat_shape', None) # doesn't support
  1092. return RotaryEmbeddingDinoV3(dim=dim // num_heads, **kwargs)
  1093. else:
  1094. raise ValueError(f"Unknown RoPE type: {rope_type}")