Skip to content

Instantly share code, notes, and snippets.

@jeffwillette
Last active Apr 13, 2021
Embed
What would you like to do?
Reproduction of an issue causing a core dump on Pytorch (1.8.1+cu111)
from argparse import Namespace
from typing import Any
import torch
from torch import nn
from torch.nn import functional as F
T = torch.Tensor
class MWNet(nn.Module):
def __init__(self, h_dim: int = 100) -> None:
super().__init__()
self.h_diom = h_dim
self.layer = nn.Sequential(nn.Linear(1, 100), nn.ReLU(inplace=True), nn.Linear(100, 1))
def forward(self, x: T) -> T:
out = self.layer(x)
return torch.softmax(out, dim=-1) # type: ignore
class Net(nn.Module):
def __init__(self, n_layers: int, in_dim: int, h_dim: int, classes: int):
super().__init__()
lyrs: Any = [nn.Linear(in_dim, h_dim)]
for i in range(n_layers):
lyrs.extend([nn.Linear(h_dim, h_dim), nn.ReLU(inplace=True), nn.Dropout(0.1)])
lyrs.append(nn.Linear(h_dim, 2))
self.layers = nn.Sequential(*lyrs)
def forward(self, x: T) -> T:
logit = self.layers(x)
return torch.softmax(logit, dim=1)
def snrandom(train: T, test: T, total_x: int, ft: int) -> None:
model = Net(5, ft, 128, 2).cuda() # type: ignore
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
mw_net = MWNet().cuda()
mw_opt = torch.optim.Adam(mw_net.parameters())
model.train()
for epoch in range(200):
for x in train:
x, y = x.cuda(), torch.randint(0, 2, (x.size(0),)).cuda()
opt.zero_grad()
mw_opt.zero_grad()
yhat = model(x)
loss = F.cross_entropy(yhat, y.long(), reduction='none')
w = mw_net(loss.unsqueeze(-1)).squeeze(-1)
loss = (loss * w).sum()
print(f"loss: {loss.item()}")
loss.backward()
opt.step()
mw_opt.step()
model.eval()
correct, n = 0.0, 0.0
with torch.no_grad():
for (x, y) in test:
c = model(x.cuda(), y, 10)
correct += c.item()
n += x.size(0)
ood_acc = (correct / n) # type: ignore
print("test set: ", ood_acc)
if __name__ == "__main__":
args = Namespace(batch_size=32)
train, test = torch.randn(100, 32, 2), torch.randn(100, 32, 2)
snrandom(train, test, 100, 2) # type: ignore
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment