Simple neural net using PyTorch
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt


class MyDataset(Dataset):
    def __init__(self):
        self.x = torch.randn(size=(10000, 64))  # 10,000 samples, 64 dims
        self.y = torch.randint(low=0, high=2 + 1, size=(10000,))  # possible labels: {0, 1, 2}

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
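
# note: the labels are random and independent of the features, so the model can
# only fit this data by memorization; chance-level accuracy is ~1/3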

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 128)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(128, 3)

    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x
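
# note: forward returns raw logits (no softmax); nn.CrossEntropyLoss applies
# log-softmax internally, so adding a softmax here would be incorrect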

# xavier initialization
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        tqdm.write(f'using xavier initialization on weights of: {str(m)}')
        nn.init.xavier_uniform_(m.weight.data)
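# (the hasattr check skips parameter-free modules like nn.ReLU, and the dim > 1
# check skips 1-d weights; xavier_uniform_ modifies the weight tensor in place)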

# define train method
def train(_model, _dataloader, _criterion, _optimizer, _epoch, _device):
    # set model to training mode
    _model.train()
    epoch_loss = 0
    num_correct = 0
    total = 0
    tqdm_meter = tqdm(_dataloader, unit=' batches', desc=f'Epoch {_epoch}', leave=False)
    for _data, _label in tqdm_meter:
        # transfer to device
        _data, _label = _data.to(_device), _label.to(_device)
        # zero out gradients from the previous step
        _optimizer.zero_grad()
        # forward pass
        out = _model(_data)
        # compute loss
        _loss = _criterion(out, _label)
        # update epoch_loss
        epoch_loss = epoch_loss + _loss.item()
        # compute predictions
        _, pred = out.max(dim=1)
        # compute number of correct predictions in batch
        num_correct_batch = (pred == _label).sum().item()
        # update number of correct predictions so far
        num_correct = num_correct + num_correct_batch
        # update total samples so far
        total = total + _data.shape[0]
        # compute gradients
        _loss.backward()
        # compute grad norm (max_norm=inf measures the norm without actually clipping)
        _grad_norm = nn.utils.clip_grad_norm_(
            parameters=_model.parameters(),
            max_norm=float('inf')
        )
        # optimizer step
        _optimizer.step()
        # update tqdm postfix (iterating over tqdm_meter already advances the bar,
        # so an explicit tqdm_meter.update() would double-count batches)
        tqdm_meter.set_postfix(ordered_dict={
            'loss': f'{_loss.item():0.4f}', 'grad norm': f'{_grad_norm:0.4f}'
        })
    return epoch_loss, num_correct / total
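
# note: the returned epoch_loss is the sum of per-batch losses; divide by
# len(_dataloader) to report a mean loss per epoch instead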

if __name__ == "__main__":
    # set random seed
    seed = 123
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # set device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # data set
    data = MyDataset()
    # data loader
    data_loader = DataLoader(dataset=data, batch_size=64, shuffle=True)
    # define model
    model = MyModel()
    # transfer model to device
    model = model.to(device)
    # xavier initialization
    model.apply(initialize_weights)
    # print model
    tqdm.write(str(model))
    # number of trainable params
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tqdm.write(f'number of trainable parameters: {num_params:,}')
    # define criterion
    criterion = nn.CrossEntropyLoss()
    # define optimizer
    optimizer = torch.optim.Adam(model.parameters())
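    # (no arguments besides the parameters are passed, so Adam runs with its
    # defaults: lr=1e-3, betas=(0.9, 0.999))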
    # train
    losses, accs = [], []
    for epoch in range(1, 10 + 1):
        loss, acc = train(model, data_loader, criterion, optimizer, epoch, device)
        losses.append(loss)
        accs.append(acc)
    # plot loss and accuracy curves
    with plt.xkcd():
        fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), dpi=160)
        axs[0].plot(range(1, 10 + 1), losses, marker='o', color='xkcd:violet')
        axs[0].set_ylabel('Loss', fontsize=18)
        axs[0].set_xlabel('Epoch', fontsize=18)
        # axs[0].grid(alpha=0.4, linestyle='--')  # grid does not work with plt.xkcd
        axs[1].plot(range(1, 10 + 1), accs, marker='*', color='xkcd:red orange')
        axs[1].set_ylabel('Accuracy', fontsize=18)
        axs[1].set_xlabel('Epoch', fontsize=18)
        # axs[1].grid(alpha=0.4, linestyle='--')  # grid does not work with plt.xkcd
        plt.tight_layout()
        plt.show()