Last active
May 10, 2024 08:59
-
-
Save subhadarship/54f6d320ba34afe80766ca89a1ceb448 to your computer and use it in GitHub Desktop.
Simple neural net using PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.utils.data import Dataset, DataLoader | |
from tqdm import tqdm | |
import matplotlib.pyplot as plt | |
class MyDataset(Dataset):
    """Synthetic classification dataset: random features, random integer labels.

    Args:
        num_samples: number of examples to generate (default 10,000).
        num_features: dimensionality of each feature vector (default 64).
        num_classes: number of label classes; labels are drawn uniformly from
            {0, ..., num_classes - 1} (default 3, i.e. labels in {0, 1, 2}).
    """

    def __init__(self, num_samples: int = 10000, num_features: int = 64,
                 num_classes: int = 3):
        # features ~ N(0, 1); shape (num_samples, num_features)
        self.x = torch.randn(size=(num_samples, num_features))
        # integer labels; torch.randint's `high` is exclusive
        self.y = torch.randint(low=0, high=num_classes, size=(num_samples,))

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx):
        # returns (feature_vector, label) for one sample
        return self.x[idx], self.y[idx]
class MyModel(nn.Module):
    """Three-layer MLP classifier: in_features -> hidden -> hidden -> num_classes.

    Returns raw logits (no softmax), so it pairs with nn.CrossEntropyLoss.

    Args:
        in_features: input dimensionality (default 64).
        hidden_size: width of both hidden layers (default 128).
        num_classes: number of output classes / logits (default 3).
    """

    def __init__(self, in_features: int = 64, hidden_size: int = 128,
                 num_classes: int = 3):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch, in_features) -> logits: (batch, num_classes)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)
# xavier initialization: applied to every submodule via model.apply()
def initialize_weights(m):
    """Apply Xavier-uniform initialization to any module with a >1-D weight.

    Skips modules without a `weight` attribute (e.g. ReLU) and 1-D weights
    (biases, norm-layer scales), for which Xavier init is not defined.
    """
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        tqdm.write(f'using xavier initialization on weights of: {str(m)}')
        # operate on the parameter directly; nn.init functions already run
        # under no_grad, so the deprecated `.data` access is unnecessary
        nn.init.xavier_uniform_(m.weight)
# define train method
def train(_model, _dataloader, _criterion, _optimizer, _epoch, _device):
    """Run one training epoch over `_dataloader`.

    Args:
        _model: the nn.Module to train (updated in place).
        _dataloader: yields (data, label) batches.
        _criterion: loss function taking (logits, labels).
        _optimizer: optimizer stepping `_model`'s parameters.
        _epoch: epoch number, used only for the progress-bar label.
        _device: device each batch is moved to before the forward pass.

    Returns:
        (epoch_loss, accuracy): the SUM of per-batch losses over the epoch
        (not the mean) and the fraction of correctly classified samples.
    """
    # set model to training mode (enables dropout/batch-norm train behavior)
    _model.train()
    epoch_loss = 0
    num_correct = 0
    total = 0
    tqdm_meter = tqdm(_dataloader, unit=' batches', desc=f'Epoch {_epoch}', leave=False)
    # NOTE: iterating the tqdm-wrapped loader advances the bar once per batch
    # on its own; the original code additionally called tqdm_meter.update(),
    # which double-counted progress. That redundant call is removed.
    for _data, _label in tqdm_meter:
        # transfer batch to device
        _data, _label = _data.to(_device), _label.to(_device)
        # zero out gradients left over from the previous step
        _optimizer.zero_grad()
        # forward pass -> logits
        out = _model(_data)
        # compute loss
        _loss = _criterion(out, _label)
        # accumulate summed batch loss
        epoch_loss = epoch_loss + _loss.item()
        # predicted class = argmax over the logit dimension
        _, pred = out.max(dim=1)
        # running counts of correct predictions and samples seen
        num_correct = num_correct + (pred == _label).sum().item()
        total = total + _data.shape[0]
        # backward pass
        _loss.backward()
        # max_norm=inf performs no actual clipping; the call is used purely
        # to measure the total gradient norm for the progress bar
        _grad_norm = nn.utils.clip_grad_norm_(
            parameters=_model.parameters(),
            max_norm=float('inf')
        )
        # optimizer step
        _optimizer.step()
        # show current batch loss and grad norm on the progress bar
        tqdm_meter.set_postfix(ordered_dict={
            'loss': f'{_loss.item():0.4f}', 'grad norm': _grad_norm
        })
    return epoch_loss, num_correct / total
if __name__ == "__main__":
    # reproducibility: seed both the CPU RNG and all CUDA RNGs
    seed = 123
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # prefer GPU when available
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # synthetic dataset and loader
    data = MyDataset()
    data_loader = DataLoader(dataset=data, batch_size=64, shuffle=True)
    # build the model and move it to the target device
    model = MyModel()
    model = model.to(device)
    # xavier initialization of all weight matrices
    model.apply(initialize_weights)
    # print model architecture
    tqdm.write(str(model))
    # report the number of trainable parameters
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tqdm.write(f'number of trainable parameters: {num_params:,}')
    # cross-entropy expects raw logits from the model
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    # train; the epoch count was previously hard-coded as `10 + 1` in three
    # places (loop + both plot x-axes) — hoisted so they cannot drift apart
    num_epochs = 10
    epochs = range(1, num_epochs + 1)
    losses, accs = [], []
    for epoch in epochs:
        loss, acc = train(model, data_loader, criterion, optimizer, epoch, device)
        losses.append(loss)
        accs.append(acc)
    # plot loss and accuracy curves, one subplot each, in xkcd style
    with plt.xkcd():
        fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), dpi=160)
        axs[0].plot(epochs, losses, marker='o', color='xkcd:violet')
        axs[0].set_ylabel('Loss', fontsize=18)
        axs[0].set_xlabel('Epoch', fontsize=18)
        # axs[0].grid(alpha=0.4, linestyle='--') # grid does not work with plt.xkcd
        axs[1].plot(epochs, accs, marker='*', color='xkcd:red orange')
        axs[1].set_ylabel('Accuracy', fontsize=18)
        axs[1].set_xlabel('Epoch', fontsize=18)
        # axs[1].grid(alpha=0.4, linestyle='--') # grid does not work with plt.xkcd
        plt.tight_layout()
        plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.