This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Method 1: hand the layers to nn.Sequential positionally; they are
# registered in order under the auto-generated names "0", "1", "2".
network1 = nn.Sequential(*(
    nn.Flatten(start_dim=1),                  # keep dim 0, flatten the rest
    nn.Linear(in_features, out_features),     # flattened input -> hidden
    nn.Linear(out_features, out_classes),     # hidden -> one score per class
))
network1
'''Sequential( | |
(0): Flatten() | |
(1): Linear(in_features=784, out_features=392, bias=True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.nn.functional as F | |
import torchvision | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader | |
from torch.utils.tensorboard import SummaryWriter | |
from IPython.display import display, clear_output |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
from collections import namedtuple | |
from itertools import product | |
class RunBuilder(): | |
@staticmethod | |
def get_runs(params): | |
Run = namedtuple('Run', params.keys()) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import product

# Hyperparameter grid: every combination of these values is one candidate run.
parameters = dict(
    lr = [.01, .001]
    ,batch_size = [100, 1000]
    ,shuffle = [True, False]
)
# One list of candidate values per hyperparameter. Dict insertion order
# (lr, batch_size, shuffle) matches the tuple unpacking in the loop below.
para_values = [v for v in parameters.values()]
for lr, batch_size, shuffle in product(*para_values):
    # Cartesian product: 2 * 2 * 2 = 8 lines are printed.
    print(lr, batch_size, shuffle)
'''0.01 100 True |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.tensorboard import SummaryWriter | |
#first example | |
tb = SummaryWriter() | |
images, labels = next(iter(train_loader)) | |
grid = torchvision.utils.make_grid(images) | |
tb.add_image('images', grid) | |
tb.add_graph(network, images) #show the network |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scikitplot as skplt | |
import matplotlib.pyplot as plt | |
def get_all_preds(model, loader): | |
all_preds = torch.tensor([]) | |
for batch in loader: | |
images, labels = batch | |
preds = model(images) | |
all_preds = torch.cat( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
for epoch in range(5): | |
total_loss = 0 | |
total_correct = 0 | |
for batch in train_loader: | |
images, labels = batch | |
preds = network(images) | |
loss = F.cross_entropy(preds, labels) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_num_correct(preds, labels):
    """Count rows of ``preds`` whose highest-scoring index equals the label.

    Args:
        preds: tensor of per-class scores; the class is picked along dim 1.
        labels: tensor of target class indices, compared element-wise with
            the per-row argmax.

    Returns:
        The number of matches as a plain Python number (via ``.item()``).
    """
    predicted_classes = preds.argmax(dim=1)
    hits = predicted_classes.eq(labels)
    return hits.sum().item()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ResNet(nn.Module): | |
def __init__(self, block, n_size, num_classes=10): | |
super(ResNet, self).__init__() | |
self.inplane = 16 | |
self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=3, stride=1, padding=1, bias=False) | |
self.bn1 = nn.BatchNorm2d(self.inplane) | |
self.relu = nn.ReLU(inplace=True) | |
self.layer1 = self._make_layer(block, 16, blocks=n_size, stride=1) | |
self.layer2 = self._make_layer(block, 32, blocks=n_size, stride=2) | |
self.layer3 = self._make_layer(block, 64, blocks=n_size, stride=2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#当想要用大批量进行训练,但是 GPU 资源有限,此时可以通过梯度累加(accumulating gradients)的方式进行。 | |
#梯度累加的基本思想在于,在优化器更新参数前,也就是执行 optimizer.step() 前,进行多次反向传播,使得梯度累计值自动保存在 parameter.grad 中,最后使用累加的梯度进行参数更新。 | |
#这个在PyTorch中特别容易实现,因为PyTorch中,梯度值本身会保留,除非我们调用 model.zero_grad() 或 optimizer.zero_grad()。 | |
model.zero_grad() # 重置保存梯度值的张量 | |
for i, (inputs, labels) in enumerate(training_set): | |
predictions = model(inputs) # 前向计算 | |
loss = loss_function(predictions, labels) # 计算损失函数 | |
loss.backward() # 计算梯度 |
NewerOlder