Skip to content

Instantly share code, notes, and snippets.

@yongjun823
Created September 18, 2019 11:27
Show Gist options
  • Save yongjun823/ee387e8f3b424c310bbff80222c0d6e0 to your computer and use it in GitHub Desktop.
Save yongjun823/ee387e8f3b424c310bbff80222c0d6e0 to your computer and use it in GitHub Desktop.
PRSRGAN test
import torch
import torch.nn as nn
import torchvision.models as models
from pprint import pprint
class Net(nn.Module):
    """VGG19 convolutional trunk with every max-pooling layer removed.

    The module keeps the layers of ``vgg19().features`` (convolutions and
    activations) in their original order, skipping each ``nn.MaxPool2d``.
    Because no pooling is applied, the feature maps are not downsampled by
    pooling.  NOTE(review): with torchvision's VGG19 (3x3 convs, padding 1)
    this should keep the input's spatial size. Confirm for your version.
    """

    def __init__(self, *, pretrained=False, verbose=True):
        """Build the pooling-free VGG19 trunk.

        Args:
            pretrained: forwarded to ``torchvision.models.vgg19``.  ``False``
                (the default) gives randomly initialised weights.  NOTE:
                newer torchvision deprecates ``pretrained`` in favour of
                ``weights=``. It is kept here for compatibility with older
                releases.
            verbose: when true, pretty-print the top-level children of the
                VGG19 model (the original debugging output).
        """
        super().__init__()
        vgg = models.vgg19(pretrained=pretrained)
        if verbose:
            pprint(list(vgg.children()))
        # ``features`` is the convolutional part of torchvision's VGG.
        # The original code reached it as ``list(children())[:-2][0]``.
        self.net = self._drop_max_pool(vgg.features)

    @staticmethod
    def _drop_max_pool(layers):
        """Return an ``nn.Sequential`` of *layers* without ``nn.MaxPool2d``.

        The kept modules are the same objects, in the same order. They are
        not copied.
        """
        # isinstance is robust where matching on the class name ("Max") is not.
        return nn.Sequential(
            *(layer for layer in layers if not isinstance(layer, nn.MaxPool2d))
        )

    def forward(self, x):
        """Run *x* through the pooling-free trunk and return the features."""
        return self.net(x)
# Smoke test: push one random 224x224 RGB batch through Net and print the
# output shape.  Falls back to CPU when CUDA is not available instead of
# crashing on the hard-coded "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
model.eval()
t = torch.randn((1, 3, 224, 224)).to(device)
# Inference only: disabling autograd avoids keeping every intermediate
# activation alive, which is large for a trunk with no pooling.
with torch.no_grad():
    out = model(t)
print(out.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment