classification.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Module containing utilities for classification."""
  18. import torch
  19. from torch import nn
  20. class ClassificationHead(nn.Module):
  21. """Module to be used as a classification head.
  22. Args:
  23. embed_size: the logits tensor coming from the networks.
  24. num_classes: an integer representing the numbers of classes to classify.
  25. Example:
  26. >>> feat = torch.rand(1, 256, 256)
  27. >>> head = ClassificationHead(256, 10)
  28. >>> head(feat).shape
  29. torch.Size([1, 10])
  30. """
  31. def __init__(self, embed_size: int = 768, num_classes: int = 10) -> None:
  32. super().__init__()
  33. self.norm = nn.LayerNorm(embed_size)
  34. self.linear = nn.Linear(embed_size, num_classes)
  35. def forward(self, x: torch.Tensor) -> torch.Tensor:
  36. out = x.mean(-2)
  37. return self.linear(self.norm(out))