Skip to content

Instantly share code, notes, and snippets.

@datakop
Last active April 26, 2017 20:56
Show Gist options
  • Save datakop/8ce2c9c543bb72bf9f8606058145d479 to your computer and use it in GitHub Desktop.
Save datakop/8ce2c9c543bb72bf9f8606058145d479 to your computer and use it in GitHub Desktop.
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
class Net(nn.Module):
    """Fully convolutional encoder-decoder built from Conv2d + BatchNorm + ReLU.

    Input is a batch with a single channel, ``(N, 1, H, W)`` (``conv1`` has
    ``in_channels=1``). Output has 3 channels with values in ``[0, 1]`` and
    spatial size ``(8 * ceil(H / 8), 8 * ceil(W / 8))``, i.e. exactly
    ``(N, 3, H, W)`` when H and W are multiples of 8.

    The encoder halves the resolution three times (stride-2 convs ``conv1``,
    ``conv3``, ``conv5``). The decoder restores it with three x2 bilinear
    upsampling layers. NOTE(review): the 1 -> 3 channel mapping suggests a
    grayscale-to-RGB task; confirm the intended use with the author.

    The module attribute names (``conv1``, ``bn1``, ...) determine the
    ``state_dict`` keys, so keep them stable if weights are exported by name.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Encoder. Every 3x3 conv uses padding=1, so stride-1 convs keep the
        # spatial size and stride-2 convs produce ceil(H/2) x ceil(W/2).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        # Decoder convs (all stride 1). Numbering jumps from 6 to 11; the names
        # are kept as-is because they are the state_dict keys.
        self.conv11 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv12 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv13 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv14 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv15 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv16 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
        # Output head: 3 channels, followed by a sigmoid in forward() (no BatchNorm).
        self.conv17 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=3, stride=1, padding=1)

        # One BatchNorm per conv except conv17; num_features matches that
        # conv's out_channels.
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.bn2 = nn.BatchNorm2d(num_features=128)
        self.bn3 = nn.BatchNorm2d(num_features=128)
        self.bn4 = nn.BatchNorm2d(num_features=256)
        self.bn5 = nn.BatchNorm2d(num_features=256)
        self.bn6 = nn.BatchNorm2d(num_features=512)
        self.bn11 = nn.BatchNorm2d(num_features=512)
        self.bn12 = nn.BatchNorm2d(num_features=256)
        self.bn13 = nn.BatchNorm2d(num_features=128)
        self.bn14 = nn.BatchNorm2d(num_features=64)
        self.bn15 = nn.BatchNorm2d(num_features=64)
        self.bn16 = nn.BatchNorm2d(num_features=32)

        # x2 bilinear upsampling with align_corners=True. This is exactly what
        # nn.UpsamplingBilinear2d(scale_factor=2) does; nn.Upsample is the
        # non-deprecated spelling. These layers have no parameters.
        self.ups1 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.ups2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.ups3 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

    def forward(self, x):
        """Map an ``(N, 1, H, W)`` batch to an ``(N, 3, H', W')`` batch in [0, 1].

        ``H' = 8 * ceil(H / 8)`` and ``W' = 8 * ceil(W / 8)``.
        """
        # Encoder: resolution goes full -> 1/2 -> 1/4 -> 1/8 (rounded up).
        x = F.relu(self.bn1(self.conv1(x)))  # stride 2 -> 1/2
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))  # stride 2 -> 1/4
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))  # stride 2 -> 1/8
        x = F.relu(self.bn6(self.conv6(x)))

        # Decoder: reduce channels, then upsample back to full resolution.
        x = F.relu(self.bn11(self.conv11(x)))
        x = F.relu(self.bn12(self.conv12(x)))
        x = F.relu(self.bn13(self.conv13(x)))
        x = self.ups1(x)  # -> 1/4
        x = F.relu(self.bn14(self.conv14(x)))
        x = F.relu(self.bn15(self.conv15(x)))
        x = self.ups2(x)  # -> 1/2
        x = F.relu(self.bn16(self.conv16(x)))
        # Sigmoid bounds the output to [0, 1]. It is applied before the final
        # upsample; bilinear interpolation is a convex combination of
        # neighbours, so the values stay in [0, 1].
        x = torch.sigmoid(self.conv17(x))
        x = self.ups3(x)  # -> full resolution
        return x
import os

import numpy

net = Net()
# Output directory: receives one .npy file per state_dict entry. It is
# created if it does not already exist.
path_to_model_weight_dir = "./model_weights"
os.makedirs(path_to_model_weight_dir, exist_ok=True)
for name, data in net.state_dict().items():
    # numpy.save appends ".npy" to the file name, e.g. "conv1.weight" ->
    # "conv1.weight.npy". .cpu() keeps this working if the model was moved
    # to a GPU before export.
    numpy.save(
        os.path.join(path_to_model_weight_dir, name),
        data.cpu().numpy(),
        allow_pickle=False,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment