Skip to content

Instantly share code, notes, and snippets.

@datakop
Last active April 26, 2017 23:25
Show Gist options
  • Save datakop/9c5ca511526c2fd7b8ccbb4edce734fc to your computer and use it in GitHub Desktop.
Save datakop/9c5ca511526c2fd7b8ccbb4edce734fc to your computer and use it in GitHub Desktop.
require 'nn';
net = nn.Sequential()
-- nn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH).
-- Low-level features: 3x3 kernels with padding 1; conv1, conv3 and conv5 use
-- stride 2 (spatial downsampling), conv2, conv4 and conv6 use stride 1.
conv1 = nn.SpatialConvolution(1, 64, 3, 3, 2, 2, 1, 1)
conv2 = nn.SpatialConvolution(64, 128, 3, 3, 1, 1, 1, 1)
conv3 = nn.SpatialConvolution(128, 128, 3, 3, 2, 2, 1, 1)
conv4 = nn.SpatialConvolution(128, 256, 3, 3, 1, 1, 1, 1)
conv5 = nn.SpatialConvolution(256, 256, 3, 3, 2, 2, 1, 1)
conv6 = nn.SpatialConvolution(256, 512, 3, 3, 1, 1, 1, 1)
-- Middle-level features: stride-1 3x3 convolutions.
conv11 = nn.SpatialConvolution(512, 512, 3, 3, 1, 1, 1, 1)
conv12 = nn.SpatialConvolution(512, 256, 3, 3, 1, 1, 1, 1)
-- Colorization network: stride-1 3x3 convolutions narrowing to 3 output channels.
conv13 = nn.SpatialConvolution(256, 128, 3, 3, 1, 1, 1, 1)
conv14 = nn.SpatialConvolution(128, 64, 3, 3, 1, 1, 1, 1)
conv15 = nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 1, 1)
conv16 = nn.SpatialConvolution(64, 32, 3, 3, 1, 1, 1, 1)
conv17 = nn.SpatialConvolution(32, 3, 3, 3, 1, 1, 1, 1)
-- nn.SpatialBatchNormalization(nFeature, eps, momentum, affine).
-- `affine` must be the Lua boolean `true` (lowercase): a capitalised `True` is
-- just an undefined global that evaluates to nil. With affine enabled each
-- layer has weight/bias tensors, which the checkpoint loader overwrites.
bn1 = nn.SpatialBatchNormalization(64, 1e-05, 0.1, true)
bn2 = nn.SpatialBatchNormalization(128, 1e-05, 0.1, true)
bn3 = nn.SpatialBatchNormalization(128, 1e-05, 0.1, true)
bn4 = nn.SpatialBatchNormalization(256, 1e-05, 0.1, true)
bn5 = nn.SpatialBatchNormalization(256, 1e-05, 0.1, true)
bn6 = nn.SpatialBatchNormalization(512, 1e-05, 0.1, true)
bn11 = nn.SpatialBatchNormalization(512, 1e-05, 0.1, true)
bn12 = nn.SpatialBatchNormalization(256, 1e-05, 0.1, true)
bn13 = nn.SpatialBatchNormalization(128, 1e-05, 0.1, true)
bn14 = nn.SpatialBatchNormalization(64, 1e-05, 0.1, true)
bn15 = nn.SpatialBatchNormalization(64, 1e-05, 0.1, true)
bn16 = nn.SpatialBatchNormalization(32, 1e-05, 0.1, true)
-- 2x bilinear upsampling layers used by the colorization network.
ups1 = nn.SpatialUpSamplingBilinear(2)
ups2 = nn.SpatialUpSamplingBilinear(2)
ups3 = nn.SpatialUpSamplingBilinear(2)
-- Append `conv`, its batch-norm layer `bn` and a freshly created ReLU to the
-- global `net`, in that order.
local function add_conv_bn_relu(conv, bn)
   net:add(conv)
   net:add(bn)
   net:add(nn.ReLU())
end

-- Low-Level Features network
add_conv_bn_relu(conv1, bn1)
add_conv_bn_relu(conv2, bn2)
add_conv_bn_relu(conv3, bn3)
add_conv_bn_relu(conv4, bn4)
add_conv_bn_relu(conv5, bn5)
add_conv_bn_relu(conv6, bn6)
-- Middle-Level Features network
add_conv_bn_relu(conv11, bn11)
add_conv_bn_relu(conv12, bn12)
-- Colorization network: conv blocks interleaved with 2x upsampling; the last
-- conv has no batch-norm and is followed by a Sigmoid, then a final upsample.
add_conv_bn_relu(conv13, bn13)
net:add(ups1)
add_conv_bn_relu(conv14, bn14)
add_conv_bn_relu(conv15, bn15)
net:add(ups2)
add_conv_bn_relu(conv16, bn16)
net:add(conv17)
net:add(nn.Sigmoid())
net:add(ups3)
-- Name -> module lookup for the checkpoint loader. Each key is the layer's
-- name inside the checkpoint: its datasets are read as `module.<key>.<param>`.
modules = {
   -- convolutions
   conv1 = conv1, conv2 = conv2, conv3 = conv3,
   conv4 = conv4, conv5 = conv5, conv6 = conv6,
   conv11 = conv11, conv12 = conv12, conv13 = conv13,
   conv14 = conv14, conv15 = conv15, conv16 = conv16,
   conv17 = conv17,
   -- batch normalizations (conv17 has no matching bn layer)
   bn1 = bn1, bn2 = bn2, bn3 = bn3,
   bn4 = bn4, bn5 = bn5, bn6 = bn6,
   bn11 = bn11, bn12 = bn12, bn13 = bn13,
   bn14 = bn14, bn15 = bn15, bn16 = bn16,
}
npy4th = require 'npy4th'
--- Test whether `String` begins with the prefix `Start`.
-- Defined on the `string` table, so it can also be called as `s:starts(p)`.
-- @param String  string to inspect
-- @param Start   prefix to look for (the empty prefix always matches)
-- @treturn boolean
function string.starts(String, Start)
   local prefix_len = string.len(Start)
   local head = string.sub(String, 1, prefix_len)
   return head == Start
end
-- Put every module in evaluation mode before inference. For the
-- SpatialBatchNormalization layers this means normalizing with the stored
-- running_mean / running_var (filled in from the checkpoint further down)
-- rather than with per-batch statistics.
net:evaluate()
-- Legacy loader, kept for reference: reads each parameter from its own .npy
-- file via npy4th and casts it to double. The HDF5 loader that follows is the
-- one actually used.
-- for k, v in pairs(modules) do
-- if string.starts(k, "conv") then
-- v.weight = npy4th.loadnpy(
-- string.format("/mnt/hdd/b.kopin/model/weights_train/module.%s.weight.npy", k)
-- ):double()
-- v.bias = npy4th.loadnpy(
-- string.format("/mnt/hdd/b.kopin/model/weights_train/module.%s.bias.npy", k)
-- ):double()
-- end
-- if string.starts(k, "bn") then
-- v.weight = npy4th.loadnpy(
-- string.format("/mnt/hdd/b.kopin/model/weights_train/module.%s.weight.npy", k)
-- ):double()
-- v.bias = npy4th.loadnpy(
-- string.format("/mnt/hdd/b.kopin/model/weights_train/module.%s.bias.npy", k)
-- ):double()
-- v.running_mean = npy4th.loadnpy(
-- string.format("/mnt/hdd/b.kopin/model/weights_train/module.%s.running_mean.npy", k)
-- ):double()
-- v.running_var = npy4th.loadnpy(
-- string.format("/mnt/hdd/b.kopin/model/weights_train/module.%s.running_var.npy", k)
-- ):double()
-- end
-- end
require 'hdf5'

-- Read dataset `module.<module_name>.<field>` from an open HDF5 checkpoint and
-- return it in the tensor type of the module's current value for that field.
-- HDF5 tensors come back in the dataset's stored type, which need not match
-- the network's; converting here mirrors the `:double()` cast of the legacy
-- npy loader.
-- @param file         open torch-hdf5 file handle
-- @param module_name  key from `modules`, e.g. "conv1" or "bn3"
-- @param field        parameter name: weight, bias, running_mean, running_var
-- @param current      the module's existing tensor for `field` (may be nil)
-- @return the loaded tensor
local function read_param(file, module_name, field, current)
   local path = string.format('module.%s.%s', module_name, field)
   local loaded = file:read(path):all()
   if current == nil then
      -- nothing to match type/size against; store the tensor as read
      return loaded
   end
   -- Fail here with a readable message on a checkpoint/architecture mismatch
   -- instead of with a shape error deep inside forward().
   if loaded:nElement() ~= current:nElement() then
      error(string.format('%s: expected %d elements, checkpoint has %d',
         path, current:nElement(), loaded:nElement()))
   end
   return loaded:typeAs(current)
end

-- Parameter datasets stored in the checkpoint for each kind of module.
local CONV_FIELDS = { 'weight', 'bias' }
local BN_FIELDS = { 'weight', 'bias', 'running_mean', 'running_var' }

model_hdf5 = hdf5.open('/mnt/hdd/b.kopin/model/model.h5', 'r')
for k, v in pairs(modules) do
   local fields
   if string.starts(k, "conv") then
      fields = CONV_FIELDS
   elseif string.starts(k, "bn") then
      fields = BN_FIELDS
   end
   for _, field in ipairs(fields or {}) do
      v[field] = read_param(model_hdf5, k, field, v[field])
   end
end
model_hdf5:close()
require "torch"
require "image"
-- Convert an RGB image to grayscale as a weighted sum of its channels,
-- 0.21*R + 0.72*G + 0.07*B (luminosity weights, not a plain average).
-- image.rgb2y uses a different weight mixture, hence this hand-rolled version.
-- @param im  3-D tensor laid out as (channels, rows, cols) with 3 channels
-- @return    new (rows, cols) tensor of the default tensor type; if `im` is not
--            a 3-channel 3-D tensor, prints an error and returns `im` unchanged
function rgb2gray(im)
   -- Check the shape before reading sizes 2 and 3, so a 2-D (already gray)
   -- image reaches this guard instead of indexing past its size storage.
   if im:nDimension() ~= 3 or im:size(1) ~= 3 then
      print('<error> expected 3 channels')
      return im
   end
   local rows, cols = im:size(2), im:size(3)
   local gray = torch.Tensor(rows, cols):zero()
   -- Tensor:add(value, other) accumulates value * other into `gray` in place;
   -- im:select(1, c) is a view of channel c.
   gray:add(0.21, im:select(1, 1))
   gray:add(0.72, im:select(1, 2))
   gray:add(0.07, im:select(1, 3))
   return gray
end
-- Colorize one test image and display the results. The raw RGB prediction is
-- shown first. Then the prediction's L channel (Lab space) is replaced by the
-- input's own L channel, resized to the output resolution. That recombined
-- image is converted back to RGB and shown as well.
input = image.load("/mnt/hdd/b.kopin/tests/test/23.jpg", 3)
input_grey = rgb2gray(input)
-- input_grey is (rows, cols); the network takes (batch, channels, rows, cols).
-- Rows must come before cols, otherwise non-square images get scrambled.
out = net:forward(torch.reshape(input_grey, torch.LongStorage{1, 1,
   input_grey:size(1), -- rows (height)
   input_grey:size(2)  -- cols (width)
}))[1]
-- out = net:forward(torch.reshape(input_grey, torch.LongStorage{1,1,224,224}))
-- itorch.image(torch.reshape(input2, torch.LongStorage{1,1,224,224}))
itorch.image(out)
input_lab = image.rgb2lab(input)
out_lab = image.rgb2lab(out:clone())
-- out is (channels, rows, cols): h = rows, w = cols of the prediction
h = out:size()[out:size():size() - 1]
w = out:size()[out:size():size()]
input_l = input_lab[{{1},{},{}}][1]:clone()
out_l = out_lab:clone()[1]
-- image.scale takes (width, height)
input_l_scaled = image.scale(input_l, w, h, "bilinear") -- image.scale(input_l, out_l)
out_lab[{{1},{},{}}] = input_l_scaled
out_new = image.lab2rgb(out_lab)
itorch.image(out_new)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment