| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- """ Binary Cross Entropy w/ a few extras
- Hacked together by / Copyright 2021 Ross Wightman
- """
- from typing import Optional, Union
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class BinaryCrossEntropy(nn.Module):
    """ BCE (with logits) with optional one-hot from dense targets, label smoothing, thresholding

    Targets whose shape matches the logits are used as-is (any smoothing / softening is
    assumed to have been applied upstream). Targets of any other shape are treated as
    dense class indices, one per logit vector, and expanded to smoothed one-hot targets
    along the last (class) dimension of the logits.

    NOTE for experiments comparing CE to BCE /w label smoothing, may remove

    Args:
        smoothing: Label smoothing factor in [0, 1), only applied when dense class-index
            targets are expanded to one-hot inside `forward`.
        target_threshold: If set, targets are binarized to 0/1 with `target > threshold`
            before the loss is computed.
        weight: Optional rescaling weight forwarded to
            `F.binary_cross_entropy_with_logits`; registered as a buffer.
        reduction: Reduction passed to `F.binary_cross_entropy_with_logits`.
            Overridden (forced to 'none') when `sum_classes` is True.
        sum_classes: If True, the per-element loss is summed over the last (class)
            dimension and then averaged over all remaining elements.
        pos_weight: Optional positive-example weight forwarded to
            `F.binary_cross_entropy_with_logits`. Non-tensor values are converted with
            `torch.tensor`; registered as a buffer.
    """

    def __init__(
            self,
            smoothing: float = 0.1,
            target_threshold: Optional[float] = None,
            weight: Optional[torch.Tensor] = None,
            reduction: str = 'mean',
            sum_classes: bool = False,
            pos_weight: Optional[Union[torch.Tensor, float]] = None,
    ):
        super().__init__()
        assert 0. <= smoothing < 1.0, f'smoothing must be in [0, 1), got {smoothing}'
        if pos_weight is not None and not isinstance(pos_weight, torch.Tensor):
            # Buffers must be tensors (or None), so wrap plain Python numbers / sequences.
            pos_weight = torch.tensor(pos_weight)
        self.smoothing = smoothing
        self.target_threshold = target_threshold
        # Summing over classes needs the unreduced loss; the reduction is applied in forward().
        self.reduction = 'none' if sum_classes else reduction
        self.sum_classes = sum_classes
        # Registered as buffers so they follow the module across .to() / .cuda() calls.
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """ Compute the loss.

        Args:
            x: Logits with the class dimension last.
            target: Either a tensor with the same shape as `x` (used directly as BCE
                targets), or class indices holding one index per logit vector
                (e.g. shape `x.shape[:-1]` or `x.shape[:-1] + (1,)`).

        Returns:
            The loss tensor; a scalar for 'mean' / 'sum' reduction or when `sum_classes`
            is set, otherwise the unreduced element-wise loss.
        """
        assert x.shape[0] == target.shape[0], (
            f'batch size mismatch: x has {x.shape[0]}, target has {target.shape[0]}')

        if target.shape != x.shape:
            # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
            num_classes = x.shape[-1]
            # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
            off_value = self.smoothing / num_classes
            on_value = 1. - self.smoothing + off_value
            # One class index per logit vector; reshape (rather than view) tolerates
            # non-contiguous targets, and the trailing singleton dim is what scatter_ needs.
            index = target.long().reshape(*x.shape[:-1], 1)
            # Build the smoothed one-hot target over the full logit shape so logits with
            # extra leading dims, e.g. (B, N, C), work the same as plain (B, C).
            target = torch.full(
                x.shape,
                off_value,
                device=x.device, dtype=x.dtype).scatter_(-1, index, on_value)

        if self.target_threshold is not None:
            # Make target 0, or 1 if threshold set
            target = target.gt(self.target_threshold).to(dtype=target.dtype)

        loss = F.binary_cross_entropy_with_logits(
            x, target,
            self.weight,
            pos_weight=self.pos_weight,
            reduction=self.reduction,
        )
        if self.sum_classes:
            # Sum over the class dim, then average over everything that remains.
            loss = loss.sum(-1).mean()
        return loss
|