Skip to content

Instantly share code, notes, and snippets.

@BrambleXu
Created January 18, 2020 02:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BrambleXu/cb835888d61beaafed3ff71e1a864a6e to your computer and use it in GitHub Desktop.
Save BrambleXu/cb835888d61beaafed3ff71e1a864a6e to your computer and use it in GitHub Desktop.
TensorBoard with PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
# Report the installed torch and torchvision versions, one per line.
print(*(pkg.__version__ for pkg in (torch, torchvision)), sep="\n")
class Network(nn.Module):
    """Small two-stage CNN that maps single-channel images to 10 raw class scores.

    The ``12 * 4 * 4`` input size of ``fc1`` corresponds to 28x28 inputs:
    28 -> conv(5) -> 24 -> pool(2) -> 12 -> conv(5) -> 8 -> pool(2) -> 4.
    No softmax is applied to the output.
    """

    def __init__(self):
        super().__init__()
        # Creation order decides which random numbers each layer draws at
        # init time; keep it fixed so seeded runs stay reproducible.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, t):
        """Return unnormalised scores of shape ``(batch, 10)`` for input ``t``."""
        # Feature extractor: two conv -> ReLU -> 2x2 max-pool (stride 2) stages.
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), 2, 2)
        # Collapse every dimension except the batch dimension.
        t = torch.flatten(t, 1)
        # Classifier head: two ReLU-activated hidden layers, then a linear output.
        for fc in (self.fc1, self.fc2):
            t = F.relu(fc(t))
        return self.out(t)
# FashionMNIST training split, stored under ./data; download=True lets
# torchvision fetch the files there. Samples are converted by ToTensor.
train_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    transform=transforms.Compose([transforms.ToTensor()]),
    download=True,
)
# Shuffled mini-batches of 100 samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    shuffle=True,
    batch_size=100,
)
# Log one batch of training images and the model graph to TensorBoard.
tb = SummaryWriter()
try:
    network = Network()
    # Take a single batch from the (shuffled) training loader.
    images, labels = next(iter(train_loader))
    # Tile the batch into one image for the TensorBoard "images" tab.
    grid = torchvision.utils.make_grid(images)
    tb.add_image('images', grid)
    # Record the model graph, using this batch as the example input.
    tb.add_graph(network, images)
finally:
    # Close the writer even if loading, tiling or graph recording raises,
    # so the event file is not left open.
    tb.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment