__init__.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from romatch.utils.utils import get_grid
  5. from .layers.block import Block
  6. from .layers.attention import MemEffAttention
  7. from .dinov2 import vit_large
  8. class TransformerDecoder(nn.Module):
  9. def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args,
  10. amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
  11. super().__init__(*args, **kwargs)
  12. self.blocks = blocks
  13. self.to_out = nn.Linear(hidden_dim, out_dim)
  14. self.hidden_dim = hidden_dim
  15. self.out_dim = out_dim
  16. self._scales = [16]
  17. self.is_classifier = is_classifier
  18. self.amp = amp
  19. self.amp_dtype = amp_dtype
  20. self.pos_enc = pos_enc
  21. self.learned_embeddings = learned_embeddings
  22. if self.learned_embeddings:
  23. self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
  24. def scales(self):
  25. return self._scales.copy()
  26. def forward(self, gp_posterior, features, old_stuff, new_scale):
  27. with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
  28. B,C,H,W = gp_posterior.shape
  29. x = torch.cat((gp_posterior, features), dim = 1)
  30. B,C,H,W = x.shape
  31. grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
  32. if self.learned_embeddings:
  33. pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
  34. else:
  35. pos_enc = 0
  36. tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
  37. z = self.blocks(tokens)
  38. out = self.to_out(z)
  39. out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
  40. warp, certainty = out[:, :-1], out[:, -1:]
  41. return warp, certainty, None