@makingglitches
Created September 25, 2021 12:18
Additionally, the training part:
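import torch
from torch import nn


def trainshit(model: "NeuralNetwork", samples: list[tuple[torch.Tensor, torch.Tensor]], epochs=100):
    # NeuralNetwork is the model class defined outside this snippet; its output
    # must pass through a sigmoid, because BCELoss expects values in [0, 1].
    # samples: (input, target) pairs of float64 tensors; target matches the output shape.
    model.double()  # convert parameters to float64 once, before building the optimizer
    model.train()
    criterion = nn.BCELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    print()
    for epoch in range(epochs):
        # one sample per weight update (pure stochastic gradient descent, not mini batches)
        for i, (inputs, target) in enumerate(samples):
            print(f'\repoch: {epoch} input: {i} ', end=' ')
            # clear the gradients
            optimizer.zero_grad()
            # compute the model output
            yhat = model(inputs)
            # calculate loss
            loss = criterion(yhat, target)
            # credit assignment: backpropagate the loss
            loss.backward()
            # update model weights
            optimizer.step()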
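The loop above performs one optimizer step per (input, target) pair, even though one of its comments mentions mini batches. A minimal sketch of a mini-batch variant follows, assuming the inputs and targets can be stacked into two float64 tensors. The function `train_minibatch`, the batch size of 16, the toy model and the toy data are all hypothetical illustrations, not part of the original gist; the loss, optimizer and learning-rate settings mirror the original.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_minibatch(model: nn.Module, inputs: torch.Tensor, targets: torch.Tensor,
                    epochs: int = 100, batch_size: int = 16) -> list[float]:
    """Hypothetical mini-batch trainer; returns the mean training loss of each epoch."""
    model.double()  # float64 parameters, converted once before the optimizer is built
    model.train()
    criterion = nn.BCELoss()  # expects sigmoid outputs in [0, 1]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # DataLoader slices the stacked tensors into shuffled batches each epoch
    loader = DataLoader(TensorDataset(inputs, targets), batch_size=batch_size, shuffle=True)
    history = []
    for epoch in range(epochs):
        total = 0.0
        for x, y in loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)  # mean loss over the batch
            loss.backward()
            optimizer.step()
            total += loss.item() * x.size(0)  # re-weight by batch size
        history.append(total / len(loader.dataset))
    return history


# Toy usage (hypothetical data): label a 2-D point 1 when its coordinates sum to > 0
torch.manual_seed(0)
X = torch.randn(64, 2, dtype=torch.float64)
Y = (X[:, :1] + X[:, 1:] > 0).double()  # shape (64, 1), same as the model output
toy = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid())
losses = train_minibatch(toy, X, Y, epochs=50)
print(f"first epoch loss {losses[0]:.3f}, last epoch loss {losses[-1]:.3f}")
```

Averaging the loss over a batch gives each update a less noisy gradient than a single sample, at the cost of fewer updates per epoch.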