Skip to content

Instantly share code, notes, and snippets.

@trsqxyz
Last active June 22, 2016 13:55
Show Gist options
  • Save trsqxyz/16a5576475fc97489daaee92380a1d7b to your computer and use it in GitHub Desktop.
Save trsqxyz/16a5576475fc97489daaee92380a1d7b to your computer and use it in GitHub Desktop.
colorize.diff
diff --git a/colorize.lua b/colorize.lua
index 060ef8b..dd65355 100644
--- a/colorize.lua
+++ b/colorize.lua
@@ -16,9 +16,14 @@
require 'nn'
require 'nngraph'
require 'image'
+require 'lfs'
-local infile = arg[1]
-local outfile = arg[2] or 'out.png'
+local outfile = 'out.png'
+-- Root directory whose sub-directories hold the images to colorize;
+-- taken from the first command-line argument, else this placeholder.
+local dirname = arg[1] or 'YOUR DIR PATH'
+-- Normalize to a trailing '/' so 'dirname..name' joins form valid paths.
+if dirname:sub(-1) ~= '/' then dirname = dirname .. '/' end
local d = torch.load( 'colornet.t7' )
local datamean = d.mean
@@ -35,16 +37,51 @@ local function pred2rgb( x, data )
return image.yuv2rgb( torch.cat( x, O[{{2,3},{},{}}], 1 ) )
end
-local I = image.load( infile )
-if I:size(1)==3 then I = image.rgb2y(I) end
-local X2 = image.scale( I, torch.round(I:size(3)/8)*8, torch.round(I:size(2)/8)*8 ):add(-datamean):float()
-local X1 = image.scale( X2, 224, 224 ):float()
-X1 = X1:reshape( 1, X1:size(1), X1:size(2), X1:size(3) )
-X2 = X2:reshape( 1, X2:size(1), X2:size(2), X2:size(3) )
-model.forwardnodes[9].data.module.modules[3].nfeatures = X2:size(3)/8
-model.forwardnodes[9].data.module.modules[4].nfeatures = X2:size(4)/8
-
-image.save( outfile, pred2rgb( I:float(), model:forward( {X1, X2} ) ) )
+--- Join two path components with exactly one '/' between them.
+local function join_path( dir, name )
+  if dir:sub(-1) == '/' then return dir .. name end
+  return dir .. '/' .. name
+end
+
+--- Colorize one image file and save the result.
+-- @tparam string in_path image to read with image.load
+-- @tparam string out_path destination passed to image.save
+local function colorize_file( in_path, out_path )
+  local I = image.load( in_path )
+  -- 3-channel input is reduced to its Y (luminance) channel.
+  if I:size(1)==3 then I = image.rgb2y(I) end
+  -- Full-size input with both sides rounded to a multiple of 8, mean-subtracted.
+  local X2 = image.scale( I, torch.round(I:size(3)/8)*8, torch.round(I:size(2)/8)*8 ):add(-datamean):float()
+  -- Fixed 224x224 copy of the same input.
+  local X1 = image.scale( X2, 224, 224 ):float()
+  X1 = X1:reshape( 1, X1:size(1), X1:size(2), X1:size(3) )
+  X2 = X2:reshape( 1, X2:size(1), X2:size(2), X2:size(3) )
+  -- X2 is now 1 x C x H x W; two modules of forward node 9 are told H/8 and
+  -- W/8 for this image (assumed required per input size -- TODO confirm).
+  model.forwardnodes[9].data.module.modules[3].nfeatures = X2:size(3)/8
+  model.forwardnodes[9].data.module.modules[4].nfeatures = X2:size(4)/8
+  image.save( out_path, pred2rgb( I:float(), model:forward( {X1, X2} ) ) )
+end
+
+-- Colorize every *.jpg file in each immediate sub-directory of dirname.
+-- Entries whose names start with '.' ('.', '..', hidden entries) are
+-- skipped; only real directories are descended into and only regular
+-- files are loaded, so stray files never reach lfs.dir or image.load.
+for file in lfs.dir(dirname) do
+  local subdir = join_path( dirname, file )
+  if file:sub(1, 1) ~= '.' and lfs.attributes( subdir, 'mode' ) == 'directory' then
+    print(file)
+    for infile in lfs.dir(subdir) do
+      local in_path = join_path( subdir, infile )
+      if infile:match('%.jpg$') and lfs.attributes( in_path, 'mode' ) == 'file' then
+        -- Output goes to the working directory, e.g. 'outcat.jpg_pets.jpg'.
+        colorize_file( in_path, 'out'..infile..'_'..file..'.jpg' )
+        -- Force a full GC cycle so the previous image's tensors can be reclaimed.
+        collectgarbage('collect')
+      end
+    end
+  end
+end
@trsqxyz
Copy link
Author

trsqxyz commented Jun 22, 2016

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment