ml_decoder.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from typing import Optional
  2. import torch
  3. from torch import nn
  4. from torch import nn, Tensor
  5. from torch.nn.modules.transformer import _get_activation_fn
  6. def add_ml_decoder_head(model):
  7. if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50
  8. model.global_pool = nn.Identity()
  9. del model.fc
  10. num_classes = model.num_classes
  11. num_features = model.num_features
  12. model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
  13. elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet
  14. model.global_pool = nn.Identity()
  15. del model.classifier
  16. num_classes = model.num_classes
  17. num_features = model.num_features
  18. model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
  19. elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head')
  20. del model.head
  21. num_classes = model.num_classes
  22. num_features = model.num_features
  23. model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
  24. else:
  25. print("Model code-writing is not aligned currently with ml-decoder")
  26. exit(-1)
  27. if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout
  28. model.drop_rate = 0
  29. return model
  30. class TransformerDecoderLayerOptimal(nn.Module):
  31. def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
  32. layer_norm_eps=1e-5) -> None:
  33. super().__init__()
  34. self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
  35. self.dropout = nn.Dropout(dropout)
  36. self.dropout1 = nn.Dropout(dropout)
  37. self.dropout2 = nn.Dropout(dropout)
  38. self.dropout3 = nn.Dropout(dropout)
  39. self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
  40. # Implementation of Feedforward model
  41. self.linear1 = nn.Linear(d_model, dim_feedforward)
  42. self.linear2 = nn.Linear(dim_feedforward, d_model)
  43. self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
  44. self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
  45. self.activation = _get_activation_fn(activation)
  46. def __setstate__(self, state):
  47. if 'activation' not in state:
  48. state['activation'] = torch.nn.functional.relu
  49. super(TransformerDecoderLayerOptimal, self).__setstate__(state)
  50. def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
  51. memory_mask: Optional[Tensor] = None,
  52. tgt_key_padding_mask: Optional[Tensor] = None,
  53. memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
  54. tgt = tgt + self.dropout1(tgt)
  55. tgt = self.norm1(tgt)
  56. tgt2 = self.multihead_attn(tgt, memory, memory)[0]
  57. tgt = tgt + self.dropout2(tgt2)
  58. tgt = self.norm2(tgt)
  59. tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
  60. tgt = tgt + self.dropout3(tgt2)
  61. tgt = self.norm3(tgt)
  62. return tgt
  63. # class ExtrapClasses(object):
  64. # def __init__(self, num_queries: int, group_size: int):
  65. # self.num_queries = num_queries
  66. # self.group_size = group_size
  67. #
  68. # def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap:
  69. # torch.Tensor):
  70. # # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)
  71. # h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups])
  72. # w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size))
  73. # out = (h * w).sum(dim=2) + class_embed_b
  74. # out = out.view((h.shape[0], self.group_size * self.num_queries))
  75. # return out
  76. class MLDecoder(nn.Module):
  77. def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
  78. super().__init__()
  79. embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
  80. if embed_len_decoder > num_classes:
  81. embed_len_decoder = num_classes
  82. self.embed_len_decoder = embed_len_decoder
  83. # switching to 768 initial embeddings
  84. decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
  85. self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)
  86. # decoder
  87. decoder_dropout = 0.1
  88. num_layers_decoder = 1
  89. dim_feedforward = 2048
  90. layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
  91. dim_feedforward=dim_feedforward, dropout=decoder_dropout)
  92. self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)
  93. # non-learnable queries
  94. self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
  95. self.query_embed.requires_grad_(False)
  96. # group fully-connected
  97. self.num_classes = num_classes
  98. self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999)
  99. self.duplicate_pooling = torch.nn.Parameter(
  100. torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor))
  101. self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
  102. torch.nn.init.xavier_normal_(self.duplicate_pooling)
  103. torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
  104. def forward(self, x):
  105. if len(x.shape) == 4: # [bs,2048, 7,7]
  106. embedding_spatial = x.flatten(2).transpose(1, 2)
  107. else: # [bs, 197,468]
  108. embedding_spatial = x
  109. embedding_spatial_786 = self.embed_standart(embedding_spatial)
  110. embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)
  111. bs = embedding_spatial_786.shape[0]
  112. query_embed = self.query_embed.weight
  113. # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
  114. tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand
  115. h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768]
  116. h = h.transpose(0, 1)
  117. out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
  118. for i in range(self.embed_len_decoder): # group FC
  119. h_i = h[:, i, :]
  120. w_i = self.duplicate_pooling[i, :, :]
  121. out_extrap[:, i, :] = torch.matmul(h_i, w_i)
  122. h_out = out_extrap.flatten(1)[:, :self.num_classes]
  123. h_out += self.duplicate_pooling_bias
  124. logits = h_out
  125. return logits