pos_embed_rel.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. """ Relative position embedding modules and functions
  2. Hacked together by / Copyright 2022 Ross Wightman
  3. """
  4. import math
  5. import os
  6. from typing import Optional, Tuple
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from .grid import ndgrid
  11. from .interpolate import RegularGridInterpolator
  12. from .mlp import Mlp
  13. from .weight_init import trunc_normal_
# Opt-in flag, read once at import time: when the TIMM_USE_SCIPY_INTERP environment
# variable parses to a positive integer, resize_rel_pos_bias_table() imports scipy and
# interpolates with scipy's interp2d instead of the RegularGridInterpolator path.
_USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0
  15. def gen_relative_position_index(
  16. q_size: Tuple[int, int],
  17. k_size: Optional[Tuple[int, int]] = None,
  18. class_token: bool = False,
  19. device=None,
  20. ) -> torch.Tensor:
  21. # Adapted with significant modifications from Swin / BeiT codebases
  22. # get pair-wise relative position index for each token inside the window
  23. assert k_size is None, 'Different q & k sizes not currently supported' # FIXME
  24. coords = torch.stack(ndgrid(
  25. torch.arange(q_size[0], device=device),
  26. torch.arange(q_size[1], device=device),
  27. )).flatten(1) # 2, Wh, Ww
  28. relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
  29. relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
  30. relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0
  31. relative_coords[:, :, 1] += q_size[1] - 1
  32. relative_coords[:, :, 0] *= 2 * q_size[1] - 1
  33. num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1)
  34. # else:
  35. # # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
  36. # q_coords = torch.stack(
  37. # ndgrid(
  38. # torch.arange(q_size[0]),
  39. # torch.arange(q_size[1])
  40. # )
  41. # ).flatten(1) # 2, Wh, Ww
  42. # k_coords = torch.stack(
  43. # ndgrid(
  44. # torch.arange(k_size[0]),
  45. # torch.arange(k_size[1])
  46. # )
  47. # ).flatten(1)
  48. # relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
  49. # relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
  50. # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
  51. # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
  52. # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
  53. # relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
  54. # num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3
  55. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  56. if class_token:
  57. # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
  58. # NOTE not intended or tested with MLP log-coords
  59. relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
  60. relative_position_index[0, 0:] = num_relative_distance
  61. relative_position_index[0:, 0] = num_relative_distance + 1
  62. relative_position_index[0, 0] = num_relative_distance + 2
  63. return relative_position_index.contiguous()
  64. def resize_rel_pos_bias_table_simple(
  65. rel_pos_bias,
  66. new_window_size: Tuple[int, int],
  67. new_bias_shape: Tuple[int, ...],
  68. ):
  69. dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
  70. if rel_pos_bias.ndim == 3:
  71. # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
  72. _, dst_h, dst_w = new_bias_shape
  73. num_attn_heads, src_h, src_w = rel_pos_bias.shape
  74. assert dst_h == dst_size[0] and dst_w == dst_size[1]
  75. if src_h != dst_h or src_w != dst_w:
  76. rel_pos_bias = torch.nn.functional.interpolate(
  77. rel_pos_bias.unsqueeze(0),
  78. size=dst_size,
  79. mode="bicubic",
  80. align_corners=False,
  81. ).squeeze(0)
  82. else:
  83. assert rel_pos_bias.ndim == 2
  84. # (num_pos, num_heads) (aka flat) bias shape
  85. dst_num_pos, _ = new_bias_shape
  86. src_num_pos, num_attn_heads = rel_pos_bias.shape
  87. num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
  88. src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
  89. src_size = (src_size, src_size) # FIXME could support non-equal src if argument passed
  90. if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
  91. if num_extra_tokens:
  92. extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
  93. rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
  94. else:
  95. extra_tokens = None
  96. rel_pos_bias = torch.nn.functional.interpolate(
  97. rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])),
  98. size=dst_size,
  99. mode="bicubic",
  100. align_corners=False,
  101. ).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1)
  102. if extra_tokens is not None:
  103. rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
  104. return rel_pos_bias
  105. def resize_rel_pos_bias_table_levit(
  106. position_bias_table,
  107. new_size,
  108. interpolation: str = 'bicubic',
  109. antialias: bool = True,
  110. ):
  111. """
  112. Resample relative position bias table suggested in LeVit
  113. Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py
  114. """
  115. L1, nH1 = position_bias_table.size()
  116. L2, nH2 = new_size
  117. assert nH1 == nH2
  118. if L1 != L2:
  119. orig_dtype = position_bias_table.dtype
  120. position_bias_table = position_bias_table.float()
  121. # bicubic interpolate relative_position_bias_table if not match
  122. S1 = int(L1 ** 0.5)
  123. S2 = int(L2 ** 0.5)
  124. relative_position_bias_table_resized = F.interpolate(
  125. position_bias_table.permute(1, 0).view(1, nH1, S1, S1),
  126. size=(S2, S2),
  127. mode=interpolation,
  128. antialias=antialias,
  129. )
  130. relative_position_bias_table_resized = relative_position_bias_table_resized.view(nH2, L2).permute(1, 0)
  131. relative_position_bias_table_resized.to(orig_dtype)
  132. return relative_position_bias_table_resized
  133. else:
  134. return position_bias_table
  135. def resize_rel_pos_bias_table(
  136. rel_pos_bias,
  137. new_window_size: Tuple[int, int],
  138. new_bias_shape: Tuple[int, ...],
  139. ):
  140. """ Resize relative position bias table using more advanced interpolation.
  141. Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc).
  142. https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351
  143. Args:
  144. rel_pos_bias:
  145. new_window_size:
  146. new_bias_shape:
  147. Returns:
  148. """
  149. if _USE_SCIPY:
  150. from scipy import interpolate
  151. dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
  152. if rel_pos_bias.ndim == 3:
  153. # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
  154. num_extra_tokens = 0
  155. _, dst_h, dst_w = new_bias_shape
  156. assert dst_h == dst_size[0] and dst_w == dst_size[1]
  157. num_attn_heads, src_h, src_w = rel_pos_bias.shape
  158. src_size = (src_h, src_w)
  159. has_flat_shape = False
  160. else:
  161. assert rel_pos_bias.ndim == 2
  162. # (num_pos, num_heads) (aka flat) bias shape
  163. dst_num_pos, _ = new_bias_shape
  164. src_num_pos, num_attn_heads = rel_pos_bias.shape
  165. num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
  166. src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
  167. src_size = (src_size, src_size)
  168. has_flat_shape = True
  169. if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
  170. # print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1]))
  171. if num_extra_tokens:
  172. extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
  173. rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
  174. else:
  175. extra_tokens = None
  176. def geometric_progression(a, r, n):
  177. return a * (1.0 - r ** n) / (1.0 - r)
  178. def _calc(src, dst):
  179. left, right = 1.01, 1.5
  180. while right - left > 1e-6:
  181. q = (left + right) / 2.0
  182. gp = geometric_progression(1, q, src // 2)
  183. if gp > dst // 2:
  184. right = q
  185. else:
  186. left = q
  187. dis = []
  188. cur = 1
  189. for i in range(src // 2):
  190. dis.append(cur)
  191. cur += q ** (i + 1)
  192. r_ids = [-_ for _ in reversed(dis)]
  193. return r_ids + [0] + dis
  194. y = _calc(src_size[0], dst_size[0])
  195. x = _calc(src_size[1], dst_size[1])
  196. yx = [torch.tensor(y), torch.tensor(x)]
  197. # print("Original positions = %s" % str(x))
  198. ty = dst_size[0] // 2.0
  199. tx = dst_size[1] // 2.0
  200. dy = torch.arange(-ty, ty + 0.1, 1.0)
  201. dx = torch.arange(-tx, tx + 0.1, 1.0)
  202. dyx = ndgrid(dy, dx)
  203. # print("Target positions = %s" % str(dx))
  204. all_rel_pos_bias = []
  205. for i in range(num_attn_heads):
  206. if has_flat_shape:
  207. z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float()
  208. else:
  209. z = rel_pos_bias[i, :, :].float()
  210. if _USE_SCIPY:
  211. # Original beit code uses scipy w/ cubic interpolation
  212. f = interpolate.interp2d(x, y, z.numpy(), kind='cubic')
  213. r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device)
  214. else:
  215. # Without scipy dependency, I've found a reasonably simple impl
  216. # that supports uneven spaced interpolation pts with 'linear' interp.
  217. # Results are comparable to scipy for model accuracy in most cases.
  218. f = RegularGridInterpolator(yx, z)
  219. r = f(dyx).contiguous().to(rel_pos_bias.device)
  220. if has_flat_shape:
  221. r = r.view(-1, 1)
  222. all_rel_pos_bias.append(r)
  223. if has_flat_shape:
  224. rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
  225. else:
  226. rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0)
  227. if extra_tokens is not None:
  228. assert has_flat_shape
  229. rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
  230. return rel_pos_bias
  231. class RelPosBias(nn.Module):
  232. """ Relative Position Bias
  233. Adapted from Swin-V1 relative position bias impl, modularized.
  234. """
  235. def __init__(
  236. self,
  237. window_size: Tuple[int, int],
  238. num_heads: int,
  239. prefix_tokens: int = 0,
  240. device=None,
  241. dtype=None,
  242. ):
  243. dd = {'device': device, 'dtype': dtype}
  244. super().__init__()
  245. assert prefix_tokens <= 1
  246. self.window_size = window_size
  247. self.window_area = window_size[0] * window_size[1]
  248. self.prefix_tokens = prefix_tokens
  249. self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
  250. num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
  251. self.relative_position_bias_table = nn.Parameter(torch.empty(num_relative_distance, num_heads, **dd))
  252. index_size = (self.window_area + prefix_tokens) ** 2
  253. self.register_buffer(
  254. "relative_position_index",
  255. torch.empty(index_size, device=device, dtype=torch.long),
  256. persistent=False,
  257. )
  258. # TODO: skip init when on meta device when safe to do so
  259. self.reset_parameters()
  260. def reset_parameters(self) -> None:
  261. """Initialize parameters and buffers."""
  262. trunc_normal_(self.relative_position_bias_table, std=.02)
  263. self._init_buffers()
  264. def _init_buffers(self) -> None:
  265. """Compute and fill non-persistent buffer values."""
  266. self.relative_position_index.copy_(
  267. gen_relative_position_index(
  268. self.window_size,
  269. class_token=self.prefix_tokens > 0,
  270. device=self.relative_position_index.device,
  271. ).view(-1)
  272. )
  273. def get_bias(self) -> torch.Tensor:
  274. relative_position_bias = self.relative_position_bias_table[self.relative_position_index]
  275. # win_h * win_w, win_h * win_w, num_heads
  276. relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
  277. return relative_position_bias.unsqueeze(0).contiguous()
  278. def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
  279. return attn + self.get_bias()
  280. def init_non_persistent_buffers(self) -> None:
  281. """Initialize non-persistent buffers."""
  282. self._init_buffers()
  283. def gen_relative_log_coords(
  284. win_size: Tuple[int, int],
  285. pretrained_win_size: Tuple[int, int] = (0, 0),
  286. mode='swin',
  287. device=None,
  288. dtype=None,
  289. ):
  290. assert mode in ('swin', 'cr')
  291. # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
  292. relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], device=device).to(torch.float32)
  293. relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], device=device).to(torch.float32)
  294. relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
  295. relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
  296. if mode == 'swin':
  297. if pretrained_win_size[0] > 0:
  298. relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
  299. relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
  300. else:
  301. relative_coords_table[:, :, 0] /= (win_size[0] - 1)
  302. relative_coords_table[:, :, 1] /= (win_size[1] - 1)
  303. relative_coords_table *= 8 # normalize to -8, 8
  304. relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
  305. 1.0 + relative_coords_table.abs()) / math.log2(8)
  306. else:
  307. # mode == 'cr'
  308. relative_coords_table = torch.sign(relative_coords_table) * torch.log(
  309. 1.0 + relative_coords_table.abs())
  310. return relative_coords_table.to(dtype)
  311. class RelPosMlp(nn.Module):
  312. """ Log-Coordinate Relative Position MLP
  313. Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)
  314. This impl covers the 'swin' implementation as well as two timm specific modes ('cr', and 'rw')
  315. """
  316. def __init__(
  317. self,
  318. window_size: Tuple[int, int],
  319. num_heads: int = 8,
  320. hidden_dim: int = 128,
  321. prefix_tokens: int = 0,
  322. mode: str = 'cr',
  323. pretrained_window_size: Tuple[int, int] = (0, 0),
  324. device=None,
  325. dtype=None,
  326. ):
  327. dd = {'device': device, 'dtype': dtype}
  328. super().__init__()
  329. self.window_size = window_size
  330. self.window_area = self.window_size[0] * self.window_size[1]
  331. self.prefix_tokens = prefix_tokens
  332. self.num_heads = num_heads
  333. self.bias_shape = (self.window_area,) * 2 + (num_heads,)
  334. self.mode = mode
  335. self.pretrained_window_size = pretrained_window_size
  336. if mode == 'swin':
  337. self.bias_act = nn.Sigmoid()
  338. self.bias_gain = 16
  339. mlp_bias = (True, False)
  340. else:
  341. self.bias_act = nn.Identity()
  342. self.bias_gain = None
  343. mlp_bias = True
  344. self.mlp = Mlp(
  345. 2, # x, y
  346. hidden_features=hidden_dim,
  347. out_features=num_heads,
  348. act_layer=nn.ReLU,
  349. bias=mlp_bias,
  350. drop=(0.125, 0.),
  351. **dd,
  352. )
  353. index_size = self.window_area ** 2
  354. rel_coords_shape = (2 * window_size[0] - 1, 2 * window_size[1] - 1, 2)
  355. self.register_buffer(
  356. "relative_position_index",
  357. torch.empty(index_size, device=device, dtype=torch.long),
  358. persistent=False,
  359. )
  360. self.register_buffer(
  361. "rel_coords_log",
  362. torch.empty(rel_coords_shape, **dd),
  363. persistent=False,
  364. )
  365. # TODO: skip init when on meta device when safe to do so
  366. self.reset_parameters()
  367. def get_bias(self) -> torch.Tensor:
  368. relative_position_bias = self.mlp(self.rel_coords_log)
  369. if self.relative_position_index is not None:
  370. relative_position_bias = relative_position_bias.view(-1, self.num_heads)[self.relative_position_index]
  371. relative_position_bias = relative_position_bias.view(self.bias_shape)
  372. relative_position_bias = relative_position_bias.permute(2, 0, 1)
  373. relative_position_bias = self.bias_act(relative_position_bias)
  374. if self.bias_gain is not None:
  375. relative_position_bias = self.bias_gain * relative_position_bias
  376. if self.prefix_tokens:
  377. relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
  378. return relative_position_bias.unsqueeze(0).contiguous()
  379. def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
  380. return attn + self.get_bias()
  381. def reset_parameters(self) -> None:
  382. """Initialize parameters and buffers."""
  383. self._init_buffers()
  384. def _init_buffers(self) -> None:
  385. """Compute and fill non-persistent buffer values."""
  386. device = self.relative_position_index.device
  387. dtype = self.rel_coords_log.dtype
  388. self.relative_position_index.copy_(
  389. gen_relative_position_index(self.window_size, device=device).view(-1)
  390. )
  391. self.rel_coords_log.copy_(
  392. gen_relative_log_coords(
  393. self.window_size,
  394. self.pretrained_window_size,
  395. mode=self.mode,
  396. device=device,
  397. dtype=dtype,
  398. )
  399. )
  400. def init_non_persistent_buffers(self) -> None:
  401. """Initialize non-persistent buffers."""
  402. self._init_buffers()
  403. def generate_lookup_tensor(
  404. length: int,
  405. max_relative_position: Optional[int] = None,
  406. device=None,
  407. dtype=None,
  408. ):
  409. """Generate a one_hot lookup tensor to reindex embeddings along one dimension.
  410. Args:
  411. length: the length to reindex to.
  412. max_relative_position: the maximum relative position to consider.
  413. Relative position embeddings for distances above this threshold
  414. are zeroed out.
  415. Returns:
  416. a lookup Tensor of size [length, length, vocab_size] that satisfies
  417. ret[n,m,v] = 1{m - n + max_relative_position = v}.
  418. """
  419. if max_relative_position is None:
  420. max_relative_position = length - 1
  421. # Return the cached lookup tensor, otherwise compute it and cache it.
  422. vocab_size = 2 * max_relative_position + 1
  423. ret = torch.zeros(length, length, vocab_size, device=device, dtype=dtype)
  424. for i in range(length):
  425. for x in range(length):
  426. v = x - i + max_relative_position
  427. if abs(x - i) > max_relative_position:
  428. continue
  429. ret[i, x, v] = 1
  430. return ret
  431. def reindex_2d_einsum_lookup(
  432. relative_position_tensor,
  433. height: int,
  434. width: int,
  435. height_lookup: torch.Tensor,
  436. width_lookup: torch.Tensor,
  437. ) -> torch.Tensor:
  438. """Reindex 2d relative position bias with 2 independent einsum lookups.
  439. Adapted from:
  440. https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
  441. Args:
  442. relative_position_tensor: tensor of shape
  443. [..., vocab_height, vocab_width, ...].
  444. height: height to reindex to.
  445. width: width to reindex to.
  446. height_lookup: one-hot height lookup
  447. width_lookup: one-hot width lookup
  448. Returns:
  449. reindexed_tensor: a Tensor of shape
  450. [..., height * width, height * width, ...]
  451. """
  452. reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
  453. reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
  454. area = height * width
  455. return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
  456. class RelPosBiasTf(nn.Module):
  457. """ Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
  458. Adapted from:
  459. https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
  460. """
  461. def __init__(
  462. self,
  463. window_size: Tuple[int, int],
  464. num_heads: int,
  465. prefix_tokens: int = 0,
  466. device=None,
  467. dtype=None,
  468. ):
  469. dd = {'device': device, 'dtype': dtype}
  470. super().__init__()
  471. assert prefix_tokens <= 1
  472. self.window_size = window_size
  473. self.window_area = window_size[0] * window_size[1]
  474. self.num_heads = num_heads
  475. vocab_height = 2 * window_size[0] - 1
  476. vocab_width = 2 * window_size[1] - 1
  477. self.bias_shape = (self.num_heads, vocab_height, vocab_width)
  478. self.relative_position_bias_table = nn.Parameter(torch.empty(self.bias_shape, **dd))
  479. height_lookup_shape = (window_size[0], window_size[0], vocab_height)
  480. width_lookup_shape = (window_size[1], window_size[1], vocab_width)
  481. self.register_buffer('height_lookup', torch.empty(height_lookup_shape, **dd), persistent=False)
  482. self.register_buffer('width_lookup', torch.empty(width_lookup_shape, **dd), persistent=False)
  483. # TODO: skip init when on meta device when safe to do so
  484. self.reset_parameters()
  485. def reset_parameters(self) -> None:
  486. """Initialize parameters and buffers."""
  487. nn.init.normal_(self.relative_position_bias_table, std=.02)
  488. self._init_buffers()
  489. def _init_buffers(self) -> None:
  490. """Compute and fill non-persistent buffer values."""
  491. device = self.height_lookup.device
  492. dtype = self.height_lookup.dtype
  493. self.height_lookup.copy_(generate_lookup_tensor(self.window_size[0], device=device, dtype=dtype))
  494. self.width_lookup.copy_(generate_lookup_tensor(self.window_size[1], device=device, dtype=dtype))
  495. def get_bias(self) -> torch.Tensor:
  496. # FIXME change to not use one-hot/einsum?
  497. return reindex_2d_einsum_lookup(
  498. self.relative_position_bias_table,
  499. self.window_size[0],
  500. self.window_size[1],
  501. self.height_lookup,
  502. self.width_lookup
  503. )
  504. def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
  505. return attn + self.get_bias()
  506. def init_non_persistent_buffers(self) -> None:
  507. """Initialize non-persistent buffers."""
  508. self._init_buffers()