pool1d.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import torch
  2. def global_pool_nlc(
  3. x: torch.Tensor,
  4. pool_type: str = 'token',
  5. num_prefix_tokens: int = 1,
  6. reduce_include_prefix: bool = False,
  7. ):
  8. """Apply global pooling to tensor in NLC format.
  9. Args:
  10. x: Input tensor in (batch, length, channels) format.
  11. pool_type: Pooling type - 'token', 'avg', 'max', 'avgmax', or empty string.
  12. num_prefix_tokens: Number of prefix tokens (e.g., class token) to exclude from pooling.
  13. reduce_include_prefix: Whether to include prefix tokens in reduction.
  14. Returns:
  15. Pooled tensor.
  16. """
  17. if not pool_type:
  18. return x
  19. if pool_type == 'token':
  20. x = x[:, 0] # class token
  21. else:
  22. x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
  23. if pool_type == 'avg':
  24. x = x.mean(dim=1)
  25. elif pool_type == 'avgmax':
  26. x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
  27. elif pool_type == 'max':
  28. x = x.amax(dim=1)
  29. else:
  30. assert not pool_type, f'Unknown pool type {pool_type}'
  31. return x