| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208 |
- """ Activations (memory-efficient w/ custom autograd)
- A collection of activation functions and modules with a common interface so that they can
- easily be swapped. All have an `inplace` arg even if it is not used.
- These activations are not compatible with jit scripting or ONNX export of the model; please use
- the basic versions of the activations instead.
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import torch
- from torch import nn as nn
- from torch.nn import functional as F
def swish_fwd(x):
    """Forward computation of Swish: ``x * sigmoid(x)``.

    Args:
        x: Input tensor.

    Returns:
        A new tensor with the elementwise product of ``x`` and its sigmoid.
    """
    gate = torch.sigmoid(x)
    return x * gate
def swish_bwd(x, grad_output):
    """Backward computation of Swish from the saved input.

    The local derivative of ``x * sigmoid(x)`` is
    ``sigmoid(x) * (1 + x * (1 - sigmoid(x)))``. The sigmoid is recomputed
    from ``x`` here.

    Args:
        x: The tensor that was passed to the forward computation.
        grad_output: Gradient flowing in from the next layer.

    Returns:
        ``grad_output`` scaled elementwise by the local derivative.
    """
    sig = torch.sigmoid(x)
    local_grad = sig * (1 + x * (1 - sig))
    return grad_output * local_grad
class SwishAutoFn(torch.autograd.Function):
    """Optimised Swish with a memory-efficient custom backward.

    Only the input tensor is saved by ``forward``; ``backward`` derives the
    gradient from it.

    Inspired by a conversation between Jeremy Howard & Adam Paszke:
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """

    @staticmethod
    def symbolic(g, x):
        # Graph form of the activation: Mul(x, Sigmoid(x)).
        gate = g.op("Sigmoid", x)
        return g.op("Mul", x, gate)

    @staticmethod
    def forward(ctx, x):
        # Keep just the input around for the backward pass.
        ctx.save_for_backward(x)
        return swish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return swish_bwd(x, grad_output)
def swish_me(x, inplace=False):
    """Apply the memory-efficient Swish activation to ``x``.

    Args:
        x: Input tensor.
        inplace: Accepted for interface parity with other activations;
            it is ignored.

    Returns:
        The result of ``SwishAutoFn.apply(x)``.
    """
    activated = SwishAutoFn.apply(x)
    return activated
class SwishMe(nn.Module):
    """Module wrapper that applies ``SwishAutoFn`` to its input."""

    def __init__(self, inplace: bool = False):
        # ``inplace`` exists only so the constructor matches other activation
        # modules; it is neither stored nor used.
        super(SwishMe, self).__init__()

    def forward(self, x):
        out = SwishAutoFn.apply(x)
        return out
def mish_fwd(x):
    """Forward computation of Mish: ``x * tanh(softplus(x))``.

    Args:
        x: Input tensor.

    Returns:
        A new tensor with the Mish activation applied elementwise.
    """
    softplus_x = F.softplus(x)
    return x * torch.tanh(softplus_x)
def mish_bwd(x, grad_output):
    """Backward computation of Mish from the saved input.

    With ``t = tanh(softplus(x))`` the local derivative is
    ``t + x * sigmoid(x) * (1 - t * t)``.

    Args:
        x: The tensor that was passed to the forward computation.
        grad_output: Gradient flowing in from the next layer.

    Returns:
        ``grad_output`` scaled elementwise by the local derivative.
    """
    sig = torch.sigmoid(x)
    tanh_sp = torch.tanh(F.softplus(x))
    local_grad = tanh_sp + x * sig * (1 - tanh_sp * tanh_sp)
    return grad_output * local_grad
class MishAutoFn(torch.autograd.Function):
    """A memory-efficient variant of Mish.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function -
    https://arxiv.org/abs/1908.08681

    Only the input tensor is saved by ``forward``; ``backward`` derives the
    gradient from it.
    """

    @staticmethod
    def forward(ctx, x):
        # Keep just the input around for the backward pass.
        ctx.save_for_backward(x)
        return mish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return mish_bwd(x, grad_output)
def mish_me(x, inplace=False):
    """Apply the memory-efficient Mish activation to ``x``.

    Args:
        x: Input tensor.
        inplace: Accepted for interface parity with other activations;
            it is ignored.

    Returns:
        The result of ``MishAutoFn.apply(x)``.
    """
    activated = MishAutoFn.apply(x)
    return activated
class MishMe(nn.Module):
    """Module wrapper that applies ``MishAutoFn`` to its input."""

    def __init__(self, inplace: bool = False):
        # ``inplace`` exists only so the constructor matches other activation
        # modules; it is neither stored nor used.
        super(MishMe, self).__init__()

    def forward(self, x):
        out = MishAutoFn.apply(x)
        return out
def hard_sigmoid_fwd(x, inplace: bool = False):
    """Forward computation of hard sigmoid: ``clamp(x + 3, 0, 6) / 6``.

    Args:
        x: Input tensor.
        inplace: Accepted for interface parity; it is ignored and ``x`` is
            not modified.

    Returns:
        A new tensor with values in ``[0, 1]``.
    """
    shifted = x + 3
    return torch.clamp(shifted, min=0, max=6) / 6.
def hard_sigmoid_bwd(x, grad_output):
    """Backward computation of hard sigmoid from the saved input.

    The local derivative is ``1/6`` where ``-3 <= x <= 3`` and ``0``
    elsewhere.

    Args:
        x: The tensor that was passed to the forward computation.
        grad_output: Gradient flowing in from the next layer.

    Returns:
        ``grad_output`` scaled elementwise by the local derivative.
    """
    in_linear_region = (x >= -3.) & (x <= 3.)
    slope = torch.ones_like(x) * in_linear_region / 6.
    return grad_output * slope
class HardSigmoidAutoFn(torch.autograd.Function):
    """Hard sigmoid as a custom autograd function.

    Only the input tensor is saved by ``forward``; ``backward`` derives the
    gradient from it.
    """

    @staticmethod
    def forward(ctx, x):
        # Keep just the input around for the backward pass.
        ctx.save_for_backward(x)
        return hard_sigmoid_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return hard_sigmoid_bwd(x, grad_output)
def hard_sigmoid_me(x, inplace: bool = False):
    """Apply the memory-efficient hard sigmoid activation to ``x``.

    Args:
        x: Input tensor.
        inplace: Accepted for interface parity with other activations;
            it is ignored.

    Returns:
        The result of ``HardSigmoidAutoFn.apply(x)``.
    """
    activated = HardSigmoidAutoFn.apply(x)
    return activated
class HardSigmoidMe(nn.Module):
    """Module wrapper that applies ``HardSigmoidAutoFn`` to its input."""

    def __init__(self, inplace: bool = False):
        # ``inplace`` exists only so the constructor matches other activation
        # modules; it is neither stored nor used.
        super(HardSigmoidMe, self).__init__()

    def forward(self, x):
        out = HardSigmoidAutoFn.apply(x)
        return out
def hard_swish_fwd(x):
    """Forward computation of hard swish: ``x * clamp(x + 3, 0, 6) / 6``.

    Args:
        x: Input tensor.

    Returns:
        A new tensor with the hard swish activation applied elementwise.
    """
    gate = torch.clamp(x + 3, min=0, max=6) / 6.
    return x * gate
def hard_swish_bwd(x, grad_output):
    """Backward computation of hard swish from the saved input.

    The local derivative is piecewise:
    ``0`` for ``x < -3``, ``x / 3 + 0.5`` for ``-3 <= x <= 3`` and ``1``
    for ``x > 3``.

    Args:
        x: The tensor that was passed to the forward computation.
        grad_output: Gradient flowing in from the next layer.

    Returns:
        ``grad_output`` scaled elementwise by the local derivative.
    """
    # Start from the outer regions: 1 where x >= 3, 0 otherwise.
    outer = torch.ones_like(x) * (x >= 3.)
    in_ramp = (x >= -3.) & (x <= 3.)
    # The ramp region (both ends inclusive) overrides the outer values.
    local_grad = torch.where(in_ramp, x / 3. + .5, outer)
    return grad_output * local_grad
class HardSwishAutoFn(torch.autograd.Function):
    """A memory-efficient HardSwish activation.

    Only the input tensor is saved by ``forward``; ``backward`` derives the
    gradient from it. ``symbolic`` describes the same computation as a graph
    of ``Add`` / ``Clip`` / ``Div`` / ``Mul`` ops.
    """

    @staticmethod
    def forward(ctx, x):
        # Keep just the input around for the backward pass.
        ctx.save_for_backward(x)
        return hard_swish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_swish_bwd(x, grad_output)

    @staticmethod
    def symbolic(g, self):
        """Build the graph for ``self * Clip(self + 3, 0, 6) / 6``.

        Args:
            g: Graph builder exposing ``op(name, *inputs, **attrs)``.
            self: The input graph value. This is a static method, so ``self``
                is the activation input rather than an instance; the name is
                kept for compatibility with existing callers.

        Returns:
            The graph value produced by the final ``Mul`` op.
        """
        def const(value):
            # Scalar float constant node; built on demand so the op creation
            # order matches the order the values are consumed.
            return g.op('Constant', value_t=torch.tensor(value, dtype=torch.float))

        shifted = g.op("Add", self, const(3))
        clipped = g.op("Clip", shifted, const(0), const(6))
        gate = g.op("Div", clipped, const(6))
        return g.op("Mul", self, gate)
def hard_swish_me(x, inplace=False):
    """Apply the memory-efficient hard swish activation to ``x``.

    Args:
        x: Input tensor.
        inplace: Accepted for interface parity with other activations;
            it is ignored.

    Returns:
        The result of ``HardSwishAutoFn.apply(x)``.
    """
    activated = HardSwishAutoFn.apply(x)
    return activated
class HardSwishMe(nn.Module):
    """Module wrapper that applies ``HardSwishAutoFn`` to its input."""

    def __init__(self, inplace: bool = False):
        # ``inplace`` exists only so the constructor matches other activation
        # modules; it is neither stored nor used.
        super(HardSwishMe, self).__init__()

    def forward(self, x):
        out = HardSwishAutoFn.apply(x)
        return out
def hard_mish_fwd(x):
    """Forward computation of hard Mish: ``0.5 * x * clamp(x + 2, 0, 2)``.

    Args:
        x: Input tensor.

    Returns:
        A new tensor with the hard Mish activation applied elementwise.
    """
    half_x = 0.5 * x
    return half_x * torch.clamp(x + 2, min=0, max=2)
def hard_mish_bwd(x, grad_output):
    """Backward computation of hard Mish from the saved input.

    The local derivative is piecewise:
    ``0`` for ``x < -2``, ``x + 1`` for ``-2 <= x <= 0`` and ``1`` for
    ``x > 0``.

    Args:
        x: The tensor that was passed to the forward computation.
        grad_output: Gradient flowing in from the next layer.

    Returns:
        ``grad_output`` scaled elementwise by the local derivative.
    """
    # Start from a step: 1 where x >= -2, 0 otherwise.
    step = torch.ones_like(x) * (x >= -2.)
    in_curve = (x >= -2.) & (x <= 0.)
    # The quadratic region (both ends inclusive) overrides the step values.
    local_grad = torch.where(in_curve, x + 1., step)
    return grad_output * local_grad
class HardMishAutoFn(torch.autograd.Function):
    """A memory-efficient variant of Hard Mish.

    Experimental, based on notes by Mish author Diganta Misra at
    https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md

    Only the input tensor is saved by ``forward``; ``backward`` derives the
    gradient from it.
    """

    @staticmethod
    def forward(ctx, x):
        # Keep just the input around for the backward pass.
        ctx.save_for_backward(x)
        return hard_mish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return hard_mish_bwd(x, grad_output)
def hard_mish_me(x, inplace: bool = False):
    """Apply the memory-efficient hard Mish activation to ``x``.

    Args:
        x: Input tensor.
        inplace: Accepted for interface parity with other activations;
            it is ignored.

    Returns:
        The result of ``HardMishAutoFn.apply(x)``.
    """
    activated = HardMishAutoFn.apply(x)
    return activated
class HardMishMe(nn.Module):
    """Module wrapper that applies ``HardMishAutoFn`` to its input."""

    def __init__(self, inplace: bool = False):
        # ``inplace`` exists only so the constructor matches other activation
        # modules; it is neither stored nor used.
        super(HardMishMe, self).__init__()

    def forward(self, x):
        out = HardMishAutoFn.apply(x)
        return out
|