Skip to content

Instantly share code, notes, and snippets.

@aurotripathy
Last active March 20, 2023 07:30
Show Gist options
  • Save aurotripathy/67e41361d67bf2e7122334862d17e47b to your computer and use it in GitHub Desktop.
Save aurotripathy/67e41361d67bf2e7122334862d17e47b to your computer and use it in GitHub Desktop.
""" Demonstrates the ease of integrating a custom layer into a PyTorch model """
import math
import torch
import torch.nn as nn
import numpy as np
class MyLinearLayer(nn.Module):
    """Custom linear layer that mimics ``nn.Linear``: ``y = x @ W^T + b``.

    Args:
        size_in: number of input features (size of the last dimension of ``x``).
        size_out: number of output features.

    Attributes:
        weights: learnable weight matrix of shape ``(size_out, size_in)``.
        bias: learnable bias vector of shape ``(size_out,)``.
    """

    def __init__(self, size_in, size_out):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        # nn.Parameter is a Tensor that is registered as a module parameter.
        # torch.empty allocates uninitialised storage (like the legacy
        # torch.Tensor(...) constructor); reset_parameters() fills it in.
        self.weights = nn.Parameter(torch.empty(size_out, size_in))
        self.bias = nn.Parameter(torch.empty(size_out))
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise ``weights`` and ``bias`` the same way ``nn.Linear`` does."""
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))  # weight init
        # For a (size_out, size_in) weight matrix the fan-in is size_in.
        fan_in = self.size_in
        # Guard against division by zero for a degenerate size_in == 0 layer.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)  # bias init

    def forward(self, x):
        """Return ``x @ weights.T + bias``.

        ``x`` may have any number of leading dimensions, ``(*, size_in)``;
        the result has shape ``(*, size_out)``.
        """
        # torch.matmul broadcasts over leading dims; torch.mm is 2-D only.
        w_times_x = torch.matmul(x, self.weights.t())
        return torch.add(w_times_x, self.bias)  # w times x + b

    def extra_repr(self):
        return f"size_in={self.size_in}, size_out={self.size_out}"
class BasicModel(nn.Module):
    """Tiny demo network: a 3x3 convolution followed by the custom linear layer.

    For a single-channel 3x4 image the convolution (no padding) produces
    128 channels of 1x2 maps, i.e. 128 * 1 * 2 = 256 features per sample,
    which is the input width the linear layer is built for.
    """

    def __init__(self):
        super().__init__()
        # Creation order (conv first, then linear) determines how the seeded
        # RNG is consumed during parameter initialisation.
        self.conv = nn.Conv2d(1, 128, 3)
        # Drop-in replacement for the built-in nn.Linear(256, 2).
        self.linear = MyLinearLayer(256, 2)

    def forward(self, x):
        """Run the convolution, flatten to rows of 256 features, then apply the linear layer."""
        feature_maps = self.conv(x)
        # NOTE(review): view(-1, 256) reshapes everything into rows of 256, so
        # an input whose conv output is not exactly 256 features per sample
        # would silently change the number of rows instead of failing.
        flattened = feature_maps.view(-1, 256)
        return self.linear(flattened)
# Seed the global RNG before building the model so that parameter
# initialisation (and therefore the printed result) is repeatable.
torch.manual_seed(0)
basic_model = BasicModel()

# One image laid out as batch(=1) x channels(=1) x height(=3) x width(=4);
# every one of the three rows is [1, 2, 3, 4].
inp = np.array([[[[1, 2, 3, 4]] * 3]])
x = torch.from_numpy(inp).float()
print('Forward computation thru model:', basic_model(x))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment