Skip to content

Instantly share code, notes, and snippets.

@jzstark jzstark/#readme.md
Last active Nov 7, 2018

Embed
What would you like to do?
FST with CGraph module

Fast Neural Style Transfer - CG

Neural Style Transfer is the process of using Deep Neural Networks to render the semantic content of one image in a different artistic style. Fast Neural Style Transfer (FST) can finish this process on the order of seconds. The Computation Graph (CG) module is used to drastically reduce memory usage.

Usage

This gist implements Fast Style Transfer in Owl and provides a simple interface to use it. Here is an example:

#zoo "f57888f85f337a154b380eb8fea80f95"

FST_CG.list_styles ();; (* show all supported styles *)
FST_CG.run ~style:1 "path/to/content_img.png" "path/to/output_img.jpg" 

The run function takes one content image and writes the stylised result to a new image file whose name is chosen by the user. The input image can be in any common format: JPEG, PNG, etc. This gist contains example content images for you to use.

Currently we support six art styles. A style is selected by passing its index as ~style (the default is 0):

  0. "Udnie" by Francis Picabia
  1. "The Great Wave off Kanagawa" by Hokusai
  2. "Rain Princess" by Leonid Afremov
  3. "La Muse" by Picasso
  4. "The Scream" by Edvard Munch
  5. "The shipwreck of the Minotaur" by J. M. W. Turner

Prerequisite

This application relies on the ImageMagick tool for image format conversion and resizing. Please make sure it is installed, e.g. on Ubuntu:

sudo apt-get install imagemagick

Limit

FST only supports a fixed set of styles, each backed by its own pre-trained weight file. If you want to try a style that is not listed here, please refer to our Neural Style Transfer gist.

open Owl
#zoo "86a1748bbc898f2e42538839edba00e1" (* ImageUtils *)
module N = Dense.Ndarray.S
module E = Owl_computation_cpu_engine.Make (N)
module Compiler = Owl_neural_compiler.Make (E)
module Neural = Compiler.Neural
module Graph = Compiler.Neural.Graph
module AD = Compiler.Neural.Algodiff
module Engine = Compiler.Engine
open Neural
open Graph
open AD
(* [pack] wraps a plain ndarray first with [Engine.pack_arr] and then with
   [AD.pack_arr]; [unpack] undoes both steps in reverse order. *)
let pack arr = AD.pack_arr (Engine.pack_arr arr)

let unpack v = Engine.unpack_arr (AD.unpack_arr v)
(** Network Structure *)
(* A SAME-padded 2-D convolution followed by normalisation along axis 3
   (decay 0., training mode).  A ReLU activation is appended unless
   [~relu:false] is given (the default is [true]). *)
let conv2d_layer ?(relu=true) kernel stride nn =
  let normalised =
    nn
    |> conv2d ~padding:SAME kernel stride
    |> normalisation ~decay:0. ~training:true ~axis:3
  in
  if relu then activation Activation.Relu normalised else normalised
(* A SAME-padded transposed 2-D convolution with the given [kernel] and
   [stride], followed by normalisation along axis 3 (decay 0., training mode)
   and a ReLU activation. *)
let conv2d_trans_layer kernel stride nn =
  let upsampled = transpose_conv2d ~padding:SAME kernel stride nn in
  let normalised = normalisation ~decay:0. ~training:true ~axis:3 upsampled in
  activation Activation.Relu normalised
(* Residual block on 128 channels: two [wh] x [wh] stride-1 convolutions (the
   second one without ReLU) whose output is added back onto the block's
   input [nn]. *)
let residual_block wh nn =
  let kernel = [|wh; wh; 128; 128|] in
  let first = conv2d_layer kernel [|1; 1|] nn in
  let second = conv2d_layer ~relu:false kernel [|1; 1|] first in
  add [|nn; second|]
(* Build the style-transfer network for an [h] x [w] x 3 input.
   Encoder: a 9x9 stride-1 convolution, then two 3x3 stride-2 convolutions
   (32 -> 64 -> 128 channels).  Body: five 3x3 residual blocks.  Decoder: two
   3x3 stride-2 transposed convolutions (128 -> 64 -> 32) and a final 9x9
   convolution to 3 channels without ReLU, whose output is mapped through
   [tanh x * 150 + 127.5].  The layers are created in exactly this order. *)
let make_network h w =
  let encoded =
    input [|h; w; 3|]
    |> conv2d_layer [|9; 9; 3; 32|] [|1; 1|]
    |> conv2d_layer [|3; 3; 32; 64|] [|2; 2|]
    |> conv2d_layer [|3; 3; 64; 128|] [|2; 2|]
  in
  let transformed =
    encoded
    |> residual_block 3
    |> residual_block 3
    |> residual_block 3
    |> residual_block 3
    |> residual_block 3
  in
  let decoded =
    transformed
    |> conv2d_trans_layer [|3; 3; 128; 64|] [|2; 2|]
    |> conv2d_trans_layer [|3; 3; 64; 32|] [|2; 2|]
    |> conv2d_layer ~relu:false [|9; 9; 32; 3|] [|1; 1|]
  in
  decoded
  |> lambda (fun x -> Maths.((tanh x) * (pack_flt 150.) + (pack_flt 127.5)))
  |> get_network
(* Image helper functions *)
(* Despite its name, this does not convert anything.  It creates a fresh,
   empty temporary file whose name starts with [img_name]'s base name (without
   its extension) and ends in ".ppm", and returns that file's path. *)
let _convert img_name =
  let prefix = img_name |> Filename.basename |> Filename.remove_extension in
  Filename.temp_file prefix ".ppm"
(* Resize [img_name] to exactly [w] x [h] pixels with ImageMagick's [convert]
   (the "!" geometry flag forces the exact size) and store the result as a PPM
   in a fresh temporary file.  Returns the path of that file; the caller is
   responsible for deleting it.
   @raise Failure if [convert] exits with a non-zero status (e.g. ImageMagick
   is not installed or the input cannot be read); the temporary file is
   removed in that case. *)
let convert_img_to_ppm w h img_name =
  let temp_img = _convert img_name in
  (* Quote every argument so paths with spaces or shell metacharacters reach
     [convert] unchanged instead of being split or interpreted by the shell. *)
  let cmd =
    Printf.sprintf "convert -resize %s %s %s"
      (Filename.quote (Printf.sprintf "%dx%d!" w h))
      (Filename.quote img_name)
      (Filename.quote temp_img)
  in
  let status = Sys.command cmd in
  if status <> 0 then begin
    if Sys.file_exists temp_img then Sys.remove temp_img;
    failwith
      (Printf.sprintf "convert_img_to_ppm: `%s` exited with status %d"
         cmd status)
  end;
  temp_img
(* Write [d3array] to the image file [output_name].  The array is first saved
   as a temporary PPM with [ImageUtils.save_ppm_from_arr].  ImageMagick's
   [convert] then turns it into [output_name], choosing the format from that
   file name.  The temporary PPM is always removed afterwards.
   @raise Failure if [convert] exits with a non-zero status. *)
let convert_arr_to_img d3array output_name =
  let temp_img = _convert output_name in
  let remove_temp () =
    if Sys.file_exists temp_img then Sys.remove temp_img
  in
  (* Quoted so that paths with spaces or shell metacharacters survive. *)
  let cmd =
    Printf.sprintf "convert %s %s"
      (Filename.quote temp_img) (Filename.quote output_name)
  in
  let status =
    try
      ImageUtils.save_ppm_from_arr d3array temp_img;
      Sys.command cmd
    with e -> remove_temp (); raise e
  in
  remove_temp ();
  if status <> 0 then
    failwith
      (Printf.sprintf "convert_arr_to_img: `%s` exited with status %d"
         cmd status)
(* Return [(w, h)] for the image [img_name], as reported by
   [ImageUtils._read_ppm] after ImageMagick has converted the image to a
   temporary PPM.  The temporary PPM is always removed afterwards.
   @raise Failure if [convert] exits with a non-zero status. *)
let get_img_shape img_name =
  let temp_img = _convert img_name in
  let remove_temp () =
    if Sys.file_exists temp_img then Sys.remove temp_img
  in
  (* Quoted so that paths with spaces or shell metacharacters survive. *)
  let cmd =
    Printf.sprintf "convert %s %s"
      (Filename.quote img_name) (Filename.quote temp_img)
  in
  match
    (let status = Sys.command cmd in
     if status <> 0 then
       failwith
         (Printf.sprintf "get_img_shape: `%s` exited with status %d"
            cmd status);
     ImageUtils._read_ppm temp_img)
  with
  | _, w, h, _ -> remove_temp (); (w, h)
  | exception e -> remove_temp (); raise e
(* Styles *)
(* Supported style names.  A style's position in this array is the index
   used as the key of [style_htb] and as the [~style] argument of [run]. *)
let styles =
  [| "udnie"
   ; "wave"
   ; "rain_princess"
   ; "la_muse"
   ; "scream"
   ; "wreck"
  |]
(* Map every style index to the name of its weight file,
   "fst_<style>_cg.weight" (e.g. 0 -> "fst_udnie_cg.weight"). *)
let make_style_htb () =
  let table = Hashtbl.create 10 in
  Array.iteri
    (fun idx name -> Hashtbl.add table idx ("fst_" ^ name ^ "_cg.weight"))
    styles;
  table

(* Index -> weight-file table, built once when the module is loaded and
   consulted by [run]. *)
let style_htb = make_style_htb ()
(* Print every supported style with the index to pass as [~style] to [run]. *)
let list_styles () =
  let buf = Buffer.create 128 in
  Array.iteri
    (fun idx name -> Printf.bprintf buf "Style [%d] :\t %s\n" idx name)
    styles;
  print_endline ("Here are the usable styles:\n" ^ Buffer.contents buf)
(* FST service function *)
(* Stylise [content_img] with style number [style] and write the result to
   [output_img].  [style] defaults to 0 ("udnie"); see [list_styles] for the
   valid indices.  The network is built for the content image's own size and
   loaded with that style's weights from the zoo path.
   @raise Failure if [style] is not a known style index. *)
let run ?(style=0) content_img output_img =
  (* Resolve the style first so that an invalid index fails immediately,
     before any image conversion or network construction takes place. *)
  let style_file =
    match Hashtbl.find style_htb style with
    | file -> file
    | exception Not_found ->
      failwith "style does not exist; try to run `list_styles ()`"
  in
  let w, h = get_img_shape content_img in
  let ppm_file = convert_img_to_ppm w h content_img in
  (* [convert_img_to_ppm] hands ownership of the temporary PPM to us;
     delete it once it has been loaded, even if loading fails. *)
  let remove_ppm () =
    if Sys.file_exists ppm_file then Sys.remove ppm_file
  in
  let content_arr =
    match ImageUtils.extend_dim (ImageUtils.load_ppm ppm_file) with
    | arr -> remove_ppm (); arr
    | exception e -> remove_ppm (); raise e
  in
  let nn = make_network h w in
  Graph.init nn;
  Graph.load_weights nn (Owl_zoo_path.extend_zoo_path style_file);
  let result = Compiler.model nn (pack content_arr) |> unpack in
  convert_arr_to_img result output_img
This file has been truncated, but you can view the full file.
View raw

(Sorry about that, but we can’t show files that are this big right now.)

View raw

(Sorry about that, but we can’t show files that are this big right now.)

View raw

(Sorry about that, but we can’t show files that are this big right now.)

View raw

(Sorry about that, but we can’t show files that are this big right now.)

View raw

(Sorry about that, but we can’t show files that are this big right now.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.