Skip to content

Instantly share code, notes, and snippets.

@Eeman1113
Created May 22, 2024 20:46
Show Gist options
  • Save Eeman1113/3b9d56e0974b3bdde3c6301359188500 to your computer and use it in GitHub Desktop.
Save Eeman1113/3b9d56e0974b3bdde3c6301359188500 to your computer and use it in GitHub Desktop.
from ..module import Module
from ..parameter import Parameter
class Linear(Module):
    """Affine layer computing ``weight @ x + bias``.

    Attributes:
        input_dim: Size of the input feature dimension.
        output_dim: Size of the output feature dimension.
        weight: ``Parameter`` created with shape ``[output_dim, input_dim]``.
        bias: ``Parameter`` created with shape ``[output_dim, 1]``, or
            ``None`` when the layer was built with ``bias=False``.
    """

    def __init__(self, input_dim, output_dim, *, bias=True):
        """Create the layer's parameters.

        Args:
            input_dim: Number of input features (second weight dimension).
            output_dim: Number of output features (first weight dimension).
            bias: When false, no bias parameter is created and ``forward``
                returns the plain matrix product. Defaults to ``True``,
                which matches the previous behaviour.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight = Parameter(shape=[self.output_dim, self.input_dim])
        self.bias = Parameter(shape=[self.output_dim, 1]) if bias else None

    def forward(self, x):
        """Apply the layer to ``x``.

        Args:
            x: Right-hand operand of ``self.weight @ x``.
                NOTE(review): presumably laid out as (input_dim, n) so the
                (output_dim, 1) bias broadcasts over columns - confirm
                against callers.

        Returns:
            ``self.weight @ x``, plus ``self.bias`` when a bias is present.
        """
        z = self.weight @ x
        # inner_repr() already allows for a missing bias; keep forward()
        # usable in that state instead of failing on `... + None`.
        if self.bias is not None:
            z = z + self.bias
        return z

    def inner_repr(self):
        """Return the argument summary string describing this layer."""
        return (
            f"input_dim={self.input_dim}, output_dim={self.output_dim}, "
            f"bias={self.bias is not None}"
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.