@nqgl
Created April 25, 2024 06:51
ProLU PyTorch Implementation
import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class ProLU(torch.autograd.Function):
    # Concrete backward variants are attached to these attributes below.
    STE: torch.autograd.Function
    ReLU: torch.autograd.Function

    @staticmethod
    @custom_fwd
    def forward(ctx, m, b):
        # The gate opens only where both the pre-activation m + b and m are positive.
        gate = (m + b > 0) & (m > 0)
        ctx.save_for_backward(m, gate)
        # Pass m through where the gate is open, zero elsewhere.
        return torch.where(gate, m, 0)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        raise NotImplementedError(
            "This method should be overridden by a subclass of ProLU "
            "to provide a backward implementation."
        )


class ProLU_ReLU(ProLU):
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # ReLU-style backward: treat the gate as a constant and pass the incoming
        # gradient through to both m and b wherever the gate is open.
        m, gate = ctx.saved_tensors
        gated_grad = torch.where(gate, grad_output, 0)
        grad_m, grad_b = gated_grad.clone(), gated_grad.clone()
        # The extra trailing None gradient is tolerated (and dropped) by autograd.
        return grad_m, grad_b, None


class ProLU_STE(ProLU):
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Straight-through-estimator backward: the gradient to b is the gated
        # gradient scaled by m, and that term is also added to the gradient to m.
        m, gate = ctx.saved_tensors
        gated_grad = torch.where(gate, grad_output, 0)
        grad_b = gated_grad * m
        grad_m = gated_grad + grad_b.clone()
        return grad_m, grad_b, None


ProLU.STE = ProLU_STE
ProLU.ReLU = ProLU_ReLU


def prolu_ste(m, b):
    return ProLU_STE.apply(m, b)


def prolu_relu(m, b):
    return ProLU_ReLU.apply(m, b)
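
Below is a minimal usage sketch, not part of the original gist: the tensor shapes and values are illustrative and assume m and b share the same shape, since the backward passes return gradients with the same shape as the gated gradient.

# Illustrative usage sketch (assumed shapes; not from the original gist).
m = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)

out_ste = prolu_ste(m, b)    # ProLU with straight-through-estimator backward
out_relu = prolu_relu(m, b)  # ProLU with ReLU-style backward

# The two variants share the same forward pass; only their backward rules differ.
assert torch.equal(out_ste, out_relu)

out_ste.sum().backward()
print(m.grad.shape, b.grad.shape)  # torch.Size([4, 8]) torch.Size([4, 8])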