# Basic libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import pandas as pd

## Define functions and classes to be used
prox_plus = nn.Threshold(0, 0)  ## zero out all non-positive entries, keeping the output nonnegative
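## Illustrative check (an addition, not from the original gist): nn.Threshold(0, 0)
## maps every entry <= 0 to 0 and leaves positive entries unchanged, so it acts
## like ReLU here and is what keeps the reconstruction nonnegative.
assert torch.equal(prox_plus(torch.tensor([-1.0, 0.0, 2.0])),
                   torch.tensor([0.0, 0.0, 2.0]))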
class NMF1(nn.Module):  ## Model for Task 1: factorize Y as prox_plus(A @ A.T)
    def __init__(self, v, d):
        super(NMF1, self).__init__()
        self.A = nn.Parameter(torch.rand(v, d, requires_grad=True))

    def forward(self):
        return prox_plus(torch.matmul(self.A, torch.transpose(self.A, 0, 1)))
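## Usage sketch (illustrative, not from the original gist): the forward pass
## takes no input and returns the v x v reconstruction prox_plus(A @ A.T).
toy = NMF1(4, 2)
print(toy().shape)  # torch.Size([4, 4])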
## Task 1 training
print('Start training on Task 1...')
# Set the dimension parameters: v rows in the target, factor rank d
v = 500
d = 50
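## `gratorch` is referenced in the training loop below but is not defined in
## this gist (it was presumably loaded elsewhere). A minimal synthetic stand-in,
## assuming a nonnegative symmetric v x v target built from a random low-rank
## factor; the construction is an assumption, not part of the original code.
A_true = torch.rand(v, d)                    # hypothetical ground-truth factor
gratorch = torch.matmul(A_true, A_true.t())  # symmetric nonnegative v x v target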
task1 = NMF1(v, d)
n_epoch = 500
loss_fn = nn.MSELoss(reduction='sum')
task1loss = []  # collect the loss every 10 epochs
optimizer = optim.SGD(task1.parameters(), lr=0.00001)
for epoch in range(n_epoch):
    Y_ = task1()
    loss = loss_fn(Y_, gratorch)
    task1.zero_grad()  # clear the old gradients before backprop
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        task1loss.append(loss.item())  # .item() extracts a plain float for plotting
print('Learning curve for Task 1')
plt.plot(task1loss[1:])  # skip the first recorded loss
plt.ylabel('loss')
plt.xlabel('epoch (x10)')
plt.show()
print('Final loss on Task 1:')
print(task1loss[-1])
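## A quick sanity check of the learned factorization (an addition, not part of
## the original gist): run one forward pass without gradients and measure the
## relative reconstruction error against the target.
with torch.no_grad():
    Y_hat = task1()  # model reconstruction, prox_plus(A @ A.T)
    rel_err = torch.norm(Y_hat - gratorch) / torch.norm(gratorch)
print('Relative reconstruction error on Task 1:', rel_err.item())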