Skip to content

Instantly share code, notes, and snippets.

@TMPxyz
Created April 24, 2022 08:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TMPxyz/7dd1a847027e36046f3758a257440ad2 to your computer and use it in GitHub Desktop.
Save TMPxyz/7dd1a847027e36046f3758a257440ad2 to your computer and use it in GitHub Desktop.
simple classifier
# %%
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as td
# %% [markdown]
# ### Gen data
# %%
# Number of input features per sample; passed to Net as its input width.
INPUT_NUM = 3
# Number of output classes; calc_idx produces class indices 0..7.
OUTPUT_NUM = 8
# %%
def calc_idx(a, b, c):
    """Return the class index (0..7) obtained by thresholding each input at 0.5.

    Each input contributes one bit of the index: ``a`` is worth 4, ``b`` is
    worth 2 and ``c`` is worth 1. A bit is set when its input is ``>= 0.5``.

    The author noted "accuracy 98%" for this rule, and recorded two harder
    labelling rules that were tried instead:

    * ``4*(0.25 <= a <= 0.75) + 2*(b >= 0.5) + 1*(c >= 0.5)`` -- accuracy 50%
    * ``4*((a%0.5) > 0.25) + 2*((b%0.5) > 0.25) + 1*((c%0.5) > 0.25)``
      -- accuracy 25%
    """
    bit_weights = (4, 2, 1)
    return sum(
        weight * (value >= 0.5)
        for weight, value in zip(bit_weights, (a, b, c))
    )
# %%
# 1000 samples with 3 uniformly random features each.
x = np.random.random(size=(1000, 3))

# One integer class index per sample (the target form used with the entropy
# loss). For an MSE loss the targets would instead be one-hot lists of
# length 8 with a 1.0 at position idx.
y = [calc_idx(a, b, c) for a, b, c in x]

# First 80% of the samples are for training, the remainder for testing.
split = int(0.8 * len(x))
otrain_x, otest_x = x[:split], x[split:]
otrain_y, otest_y = y[:split], y[split:]

# `display` is not imported in this file; presumably the IPython/Jupyter builtin.
display(x[:5], y[:5])
# %% [markdown]
# ### Dataset
# %%
# Training split: float features and integer (long) class-index labels,
# wrapped in a TensorDataset and served in batches of 8.
train_x, train_y = (
    torch.tensor(otrain_x, dtype=torch.float),
    torch.tensor(otrain_y, dtype=torch.long),
)
train_ds = td.TensorDataset(train_x, train_y)
train_ld = td.DataLoader(train_ds, batch_size=8)

# Test split, built the same way as the training split.
test_x, test_y = (
    torch.tensor(otest_x, dtype=torch.float),
    torch.tensor(otest_y, dtype=torch.long),
)
test_ds = td.TensorDataset(test_x, test_y)
test_ld = td.DataLoader(test_ds, batch_size=8)
# %% [markdown]
# ### NN
# %%
hl = 15  # width of both hidden layers


class Net(nn.Module):
    """Fully connected classifier: input -> hl -> hl -> output.

    Args:
        input_num: number of input features per sample.
        output_num: number of classes (size of the output layer).

    ``forward`` returns one raw score (logit) per class for every sample.
    """

    def __init__(self, input_num, output_num):
        super().__init__()
        self._input = nn.Linear(input_num, hl)
        self._fc1 = nn.Linear(hl, hl)
        self._fc2 = nn.Linear(hl, output_num)

    def forward(self, inputs):
        """Return unnormalised class scores for a batch of ``inputs``."""
        x = torch.relu(self._input(inputs))  # using relu here is faster than using sigmoid
        x = torch.relu(self._fc1(x))
        # No activation on the output layer. This file trains with
        # nn.CrossEntropyLoss, which applies log-softmax itself and expects
        # raw logits. Squashing the scores with sigmoid (or clipping them
        # with ReLU) limits how confident the softmax can become and keeps
        # the loss artificially high. argmax over the scores is unaffected
        # by dropping the sigmoid because sigmoid is monotonic.
        return self._fc2(x)
# Build the classifier with the configured input/output sizes and show its layers.
model = Net(input_num=INPUT_NUM, output_num=OUTPUT_NUM)
print(model)
# %% [markdown]
# ### Train
# %%
def train(model, data_loader, optimizer):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    The loss is computed with the module-level ``loss_fn``.

    Args:
        model: module to train; it is switched to training mode.
        data_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.

    Returns:
        The average of the per-batch loss values.

    Raises:
        ValueError: if ``data_loader`` yields no batches, since there is
            nothing to average.
    """
    model.train()
    train_loss = 0.0
    batch_count = 0
    for data, target in data_loader:
        batch_count += 1
        # forward
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out, target)
        train_loss += loss.item()
        # backward
        loss.backward()
        optimizer.step()
    if batch_count == 0:
        raise ValueError("data_loader yielded no batches")
    avg_loss = train_loss / batch_count
    print(f"Training set, avg loss {avg_loss:.6f}")
    return avg_loss
def test(model, data_loader):
    """Evaluate ``model`` on ``data_loader`` and return ``(avg_loss, accuracy)``.

    The loss is computed with the module-level ``loss_fn``. Predictions are
    the argmax over the last dimension of the model output and are compared
    with the integer class-index targets.

    Args:
        model: module to evaluate; it is switched to eval mode.
        data_loader: iterable yielding ``(data, target)`` batches.

    Returns:
        A tuple ``(avg_loss, accu)``: the average of the per-batch loss
        values and the fraction of evaluated samples predicted correctly.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    batch_count = 0
    sample_count = 0
    with torch.no_grad():
        for data, target in data_loader:
            batch_count += 1
            out = model(data)
            test_loss += loss_fn(out, target).item()
            predicted = torch.argmax(out, dim=-1)
            # labels = torch.argmax(target, dim=-1) # if use MSELoss
            labels = target  # use CrossEntropyLoss
            correct += torch.sum(predicted == labels).item()
            # Count the samples actually seen rather than trusting
            # len(data_loader.dataset): the two differ when the loader uses
            # a sampler or drops the last batch.
            sample_count += labels.numel()
    if batch_count == 0:
        raise ValueError("data_loader yielded no batches")
    avg_loss = test_loss / batch_count
    accu = correct / sample_count
    print(f"avg_loss : {avg_loss}, accuracy = {accu}")
    return avg_loss, accu
# Loss used by train() and test(); targets are integer class indices.
loss_fn = nn.CrossEntropyLoss()

EPOCHS = 200
learning_rate = 1e-3

# Per-epoch history, plotted after training.
epoch_nums, training_loss, validation_loss, accus = [], [], [], []

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)  # RMSprop gives same accuracy as Adam(W)
optimizer.zero_grad()

# Each epoch: one training pass, one evaluation pass, then record the metrics.
for epoch in range(1, EPOCHS + 1):
    print(f"epoch {epoch}")
    train_loss = train(model, train_ld, optimizer)
    test_loss, accuracy = test(model, test_ld)

    epoch_nums.append(epoch)
    accus.append(accuracy)
    training_loss.append(train_loss)
    validation_loss.append(test_loss)
# %%
import matplotlib.pyplot as plt
%matplotlib inline
# Side-by-side figure: loss curves on the left, accuracy on the right.
plt.rcParams["figure.figsize"] = (16, 8)

loss_ax = plt.subplot(1, 2, 1)
for series, series_label in (
    (training_loss, "train loss"),
    (validation_loss, "validation loss"),
):
    loss_ax.plot(epoch_nums, series, label=series_label)
loss_ax.legend()

accu_ax = plt.subplot(1, 2, 2)
accu_ax.plot(epoch_nums, accus, label="accuracy")
accu_ax.legend()

plt.show()
# %%
# Sanity check on 100 freshly sampled points: compare the model's predicted
# class with the index computed by calc_idx.
x = np.random.random(size=(100, 3))
y = torch.tensor([calc_idx(*sample) for sample in x])

model.eval()
with torch.no_grad():
    output = model(torch.tensor(x, dtype=torch.float))

# Predicted class for each row is the position of its largest score.
got = torch.argmax(output, dim=-1)
print((y == got).sum().item(), "/", 100)

torch.set_printoptions(sci_mode=False)
# `display` is not imported in this file; presumably the IPython/Jupyter builtin.
display(x[:5], output[:5])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment