@aakashns · Last active July 15, 2018 16:00

import torch.nn as nn
import torch.nn.functional as F


def conv_2d(ni, nf, stride=1, ks=3):
    # 3x3 (by default) convolution with "same" padding and no bias,
    # since a BatchNorm layer follows and supplies the bias term.
    return nn.Conv2d(in_channels=ni, out_channels=nf,
                     kernel_size=ks, stride=stride,
                     padding=ks//2, bias=False)


def bn_relu_conv(ni, nf):
    # Pre-activation ordering: BatchNorm -> ReLU -> Conv.
    return nn.Sequential(nn.BatchNorm2d(ni),
                         nn.ReLU(inplace=True),
                         conv_2d(ni, nf))


class BasicBlock(nn.Module):
    def __init__(self, ni, nf, stride=1):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv_2d(ni, nf, stride)
        self.conv2 = bn_relu_conv(nf, nf)
        # Identity shortcut when the channel count is unchanged; otherwise
        # a 1x1 strided conv projects the input to match the main branch.
        self.shortcut = lambda x: x
        if ni != nf:
            self.shortcut = conv_2d(ni, nf, stride, 1)

    def forward(self, x):
        x = F.relu(self.bn(x), inplace=True)
        r = self.shortcut(x)
        x = self.conv1(x)
        # Scale the residual branch by 0.2 before adding the shortcut.
        x = self.conv2(x) * 0.2
        return x.add_(r)
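
Below is a minimal usage sketch (not part of the original gist): it builds a BasicBlock that changes both channel count and spatial resolution, so the 1x1 projection shortcut is exercised, and runs a random batch through it to check the output shape. The batch size, channel counts, and input resolution here are illustrative assumptions.

import torch

# Assumed example configuration: 16 -> 32 channels with stride 2.
block = BasicBlock(ni=16, nf=32, stride=2)

x = torch.randn(4, 16, 32, 32)   # batch of 4, 16 channels, 32x32 inputs
out = block(x)
print(out.shape)                 # torch.Size([4, 32, 16, 16])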