ProLU PyTorch Implementation
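As implemented below, ProLU is an element-wise gated unit: the output equals m wherever both m > 0 and m + b > 0, and is 0 elsewhere. The base class defines this shared forward pass; the two subclasses, ProLU_ReLU and ProLU_STE, differ only in how the backward pass routes gradients to m and b (a ReLU-style pass-through versus a straight-through-estimator-style pseudo-gradient).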
import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class ProLU(torch.autograd.Function):
    # Populated below with the two concrete backward implementations.
    STE: torch.autograd.Function
    ReLU: torch.autograd.Function

    @staticmethod
    @custom_fwd
    def forward(ctx, m, b):
        # The gate is open only where both the pre-activation m and the
        # biased pre-activation m + b are positive; m passes through the gate.
        gate = (m + b > 0) & (m > 0)
        ctx.save_for_backward(m, gate)
        return torch.where(gate, m, 0)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        raise NotImplementedError(
            "This method should be overridden by a subclass of ProLU to provide a backward implementation."
        )


class ProLU_ReLU(ProLU):
    # ReLU-style backward: the incoming gradient is passed through wherever
    # the gate is open, identically for m and b.
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        m, gate = ctx.saved_tensors
        gated_grad = torch.where(gate, grad_output, 0)
        grad_m, grad_b = gated_grad.clone(), gated_grad.clone()
        return grad_m, grad_b, None


class ProLU_STE(ProLU):
    # Straight-through-estimator backward: b receives the gated gradient
    # scaled by m, and m additionally receives that pseudo-gradient.
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        m, gate = ctx.saved_tensors
        gated_grad = torch.where(gate, grad_output, 0)
        grad_b = gated_grad * m
        grad_m = gated_grad + grad_b.clone()
        return grad_m, grad_b, None


ProLU.STE = ProLU_STE
ProLU.ReLU = ProLU_ReLU


def prolu_ste(m, b):
    return ProLU_STE.apply(m, b)


def prolu_relu(m, b):
    return ProLU_ReLU.apply(m, b)
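

# --- Usage sketch (added for illustration; not part of the original file) ---
# A minimal example of calling the two variants on toy tensors. The shapes,
# names, and the explicit `expand_as` broadcast of the bias are assumptions
# made here so that the gradients returned by backward match each input's
# shape; this is a sketch, not a prescribed way to use these functions.
if __name__ == "__main__":
    m = torch.randn(4, 8, requires_grad=True)   # pre-activations
    b = torch.randn(8, requires_grad=True)      # per-feature bias

    # Both variants share the same forward pass; they differ only in backward.
    out_relu = prolu_relu(m, b.expand_as(m))
    out_ste = prolu_ste(m, b.expand_as(m))
    assert torch.equal(out_relu, out_ste)

    # Gradients reach both m and b; the STE variant routes a pseudo-gradient
    # to the bias through the gate.
    out_ste.sum().backward()
    print(m.grad.shape, b.grad.shape)  # torch.Size([4, 8]) torch.Size([8])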