Skip to content

Instantly share code, notes, and snippets.

@torridgristle
Created February 28, 2022 14:18
Show Gist options
  • Save torridgristle/71a83d83f2a5494202dc5bd987321b7e to your computer and use it in GitHub Desktop.
Save torridgristle/71a83d83f2a5494202dc5bd987321b7e to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
class Residual(nn.Module):
    """Wrap a callable with an identity skip connection.

    The result of ``forward`` is ``fn(x, *args, **kwargs) + x``. Any extra
    positional or keyword arguments given to ``forward`` are handed to
    ``fn`` untouched; ``x`` itself is always the first argument.

    Args:
        fn: the wrapped callable (typically an ``nn.Module``).
    """

    def __init__(self, fn):
        super().__init__()
        # When fn is an nn.Module, assigning it as an attribute registers it
        # as a submodule, so its parameters show up in self.parameters().
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        # Skip connection. The branch output must be addable to x (for
        # tensors: same or broadcast-compatible shape).
        return branch + x
class Concat(nn.Module):
    """Concatenate the input with the output of a wrapped callable.

    ``forward`` computes ``torch.cat([x, fn(x, *args, **kwargs)], dim)``.
    The input comes first and ``fn``'s output second. Extra positional and
    keyword arguments to ``forward`` are passed through to ``fn``.

    Args:
        fn: callable (typically an ``nn.Module``) applied to ``x``.
        dim: dimension to concatenate along. The default ``1`` preserves the
            original hard-coded behaviour. Under the common ``(N, C, ...)``
            layout this is the channel axis, but that layout is assumed, not
            enforced. Negative values index from the end, as in
            ``torch.cat``.
    """

    def __init__(self, fn, *, dim=1):
        super().__init__()
        # When fn is an nn.Module, assigning it here registers it as a
        # submodule so its parameters are tracked.
        self.fn = fn
        self.dim = dim

    def forward(self, x, *args, **kwargs):
        # x and fn's output must agree in every dimension except self.dim.
        return torch.cat([x, self.fn(x, *args, **kwargs)], dim=self.dim)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment