fn = FirstNetwork_v1()
fit_v1()
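Both training loops call an `accuracy(y_hat, Y_train)` helper that is not shown in these snippets. A minimal sketch of what such a helper could look like, assuming `y_hat` holds one row of class scores per point and `Y_train` holds integer class labels (the helper body and the toy tensors are hypothetical):

```python
import torch

def accuracy(y_hat: torch.Tensor, y: torch.Tensor) -> float:
    """Hypothetical helper: fraction of rows whose highest score matches the label."""
    pred = torch.argmax(y_hat, dim=1)      # predicted class per row
    return (pred == y).float().mean().item()

# Toy scores for 4 points and 2 classes (illustrative data)
scores = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]])
labels = torch.tensor([0, 1, 1, 1])
print(accuracy(scores, labels))  # 0.75
```

Because it only uses `argmax`, this works the same whether the network returns raw scores or softmax probabilities.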
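def fit_v1(epochs=1000, learning_rate=1):
    loss_arr = []
    acc_arr = []
    # Let optim.SGD apply the parameter updates instead of doing them by hand
    opt = optim.SGD(fn.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        y_hat = fn(X_train)                        # forward pass on the full training set
        loss = F.cross_entropy(y_hat, Y_train)
        loss_arr.append(loss.item())
        acc_arr.append(accuracy(y_hat, Y_train))

        loss.backward()    # compute gradients
        opt.step()         # param -= lr * param.grad for every parameter
        opt.zero_grad()    # reset gradients so they do not accumulate across epochs
    return loss_arr, acc_arr  # per-epoch loss and accuracy history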
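`fit_v1` hands the update step to `optim.SGD`. With its default settings (no momentum, no weight decay), one `opt.step()` should be equivalent to the manual `param -= learning_rate * param.grad` update. A small toy check on hypothetical tensors (not part of the original code):

```python
import torch
from torch import optim

torch.manual_seed(0)
lr = 0.1
w_opt = torch.randn(3, requires_grad=True)                   # updated by optim.SGD
w_manual = w_opt.detach().clone().requires_grad_(True)       # updated by hand

# Optimizer-based update on a toy loss sum(w^2)
opt = optim.SGD([w_opt], lr=lr)
(w_opt ** 2).sum().backward()
opt.step()
opt.zero_grad()

# Manual update on the same loss
(w_manual ** 2).sum().backward()
with torch.no_grad():
    w_manual -= lr * w_manual.grad
w_manual.grad.zero_()

print(torch.allclose(w_opt, w_manual))  # True
```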
import math
import torch
from torch import optim
fn = FirstNetwork_v1()
fit()
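class FirstNetwork_v1(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.lin1 = nn.Linear(2, 2)   # 2 input features -> 2 hidden units
        self.lin2 = nn.Linear(2, 4)   # 2 hidden units -> 4 classes

    def forward(self, X):
        a1 = self.lin1(X)
        h1 = a1.sigmoid()             # assumed hidden activation
        a2 = self.lin2(h1)
        return a2                     # raw scores (logits); F.cross_entropy applies log-softmax itself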
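`nn.Linear(in_features, out_features)` stores its weight with shape `(out_features, in_features)` and computes `X @ weight.T + bias`, which is why it can replace hand-declared weight and bias parameters. A quick check on a hypothetical batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(2, 4)
X = torch.randn(5, 2)            # hypothetical batch: 5 points, 2 features each

out_layer = lin(X)
out_manual = X @ lin.weight.T + lin.bias

print(lin.weight.shape)                      # torch.Size([4, 2])
print(out_layer.shape)                       # torch.Size([5, 4])
print(torch.allclose(out_layer, out_manual)) # True
```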
fn = FirstNetwork()
fit()
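def fit(epochs=1000, learning_rate=1):
    loss_arr = []
    acc_arr = []
    for epoch in range(epochs):
        y_hat = fn(X_train)                        # forward pass on the full training set
        loss = F.cross_entropy(y_hat, Y_train)
        loss_arr.append(loss.item())
        acc_arr.append(accuracy(y_hat, Y_train))
        loss.backward()                            # compute gradients
        # Manual gradient-descent step; no_grad keeps the update out of the autograd graph
        with torch.no_grad():
            for param in fn.parameters():
                param -= learning_rate * param.grad
            fn.zero_grad()                         # reset gradients before the next backward()
    return loss_arr, acc_arr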
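Calling `loss.backward()` once per epoch only gives the right gradient if the previous epoch's gradients are cleared, because PyTorch adds new gradients into `.grad` rather than overwriting it. A toy illustration on a hypothetical tensor `w`:

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

(w ** 2).sum().backward()
first = w.grad.clone()         # d/dw sum(w^2) = 2w -> [2., 4.]

(w ** 2).sum().backward()
accumulated = w.grad.clone()   # added to the previous gradient -> [4., 8.]

w.grad.zero_()
(w ** 2).sum().backward()
reset = w.grad.clone()         # after zeroing, back to [2., 4.]

print(first, accumulated, reset)
```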
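class FirstNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        # 2 -> 2 -> 4 network with parameters declared by hand (scaled random init)
        self.weights1 = nn.Parameter(torch.randn(2, 2) / math.sqrt(2))
        self.bias1 = nn.Parameter(torch.zeros(2))
        self.weights2 = nn.Parameter(torch.randn(2, 4) / math.sqrt(2))
        self.bias2 = nn.Parameter(torch.zeros(4))

    def forward(self, X):
        a1 = torch.matmul(X, self.weights1) + self.bias1
        h1 = a1.sigmoid()             # assumed hidden activation
        a2 = torch.matmul(h1, self.weights2) + self.bias2
        return a2                     # logits for F.cross_entropy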
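Wrapping a tensor in `nn.Parameter` and assigning it as a module attribute is what registers it, so `fn.parameters()` in the training loops finds it; a plain tensor attribute is not registered. A sketch with a hypothetical module `TwoParams`:

```python
import math
import torch
import torch.nn as nn

class TwoParams(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(2, 4) / math.sqrt(2))
        self.bias = nn.Parameter(torch.zeros(4))
        self.not_a_param = torch.zeros(4)   # plain tensor: not returned by parameters()

m = TwoParams()
names = [name for name, _ in m.named_parameters()]
print(names)  # ['weights', 'bias']
```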
import torch.nn as nn
import torch.nn.functional as F
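`F.cross_entropy`, used by both training loops, expects raw class scores (logits) and applies log-softmax internally; it equals `F.nll_loss` on `F.log_softmax` outputs. A check on hypothetical scores and labels:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 4)                      # hypothetical scores: 6 points, 4 classes
targets = torch.tensor([0, 1, 2, 3, 0, 1])      # integer class labels

ce = F.cross_entropy(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))  # True

# Passing softmax probabilities instead of logits normalises twice
# and generally yields a different loss value.
double = F.cross_entropy(F.softmax(logits, dim=1), targets)
print(ce.item(), double.item())
```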