|
#!/usr/bin/env owl |
|
|
|
open Owl |
|
open Owl_types |
|
open Neural.S |
|
open Neural.S.Graph |
|
|
|
#zoo "51eaf74c65fa14c8c466ecfab2351bbd" (* Imagenet_cls*) |
|
#zoo "86a1748bbc898f2e42538839edba00e1" (* ImageUtils *) |
|
|
|
(** Channel-last tensor layout (the same convention as Keras Conv layers). *)
let channel_last = true

(** When [false], no final Dense layer is meant to be included. *)
let include_top = true

(** Side length of the square input image. With [include_top = true] this
    has to be exactly 299. *)
let img_size = 299

(** Location of the weight file, resolved through the zoo path helper. *)
let weight_file = Owl_zoo_path.extend_zoo_path "inception_owl.weight"
|
|
|
(** [conv2d_bn ?padding kernel stride nn] appends three layers to [nn]:
    a 2-D convolution with the given [kernel], [stride] and [padding]
    (default [SAME]), a normalisation layer over axis 3 built with
    [~training:false], and a ReLU activation. Returns the last node. *)
let conv2d_bn ?(padding = SAME) kernel stride nn =
  let convolved = conv2d ~padding kernel stride nn in
  let normalised = normalisation ~training:false ~axis:3 convolved in
  activation Activation.Relu normalised
|
|
|
(** [mix_typ1 in_shape bp_size nn] builds four parallel branches on top of
    [nn] (1x1, 5x5, double 3x3 and average-pool) and concatenates them along
    axis 3. [in_shape] is the input-channel entry of the first kernel of
    every branch; [bp_size] is the last kernel entry of the pooling branch.

    NOTE(review): branches are created in the same order as before;
    presumably auto-generated layer names depend on creation order, so
    confirm before reordering. *)
let mix_typ1 in_shape bp_size nn =
  (* Every convolution in this block uses stride 1x1 and default padding. *)
  let conv kernel node = conv2d_bn kernel [|1; 1|] node in
  let b_1x1 = conv [|1; 1; in_shape; 64|] nn in
  let b_5x5 =
    let reduced = conv [|1; 1; in_shape; 48|] nn in
    conv [|5; 5; 48; 64|] reduced
  in
  let b_3x3dbl =
    let reduced = conv [|1; 1; in_shape; 64|] nn in
    let first = conv [|3; 3; 64; 96|] reduced in
    conv [|3; 3; 96; 96|] first
  in
  let b_pool =
    let pooled = avg_pool2d [|3; 3|] [|1; 1|] nn in
    conv [|1; 1; in_shape; bp_size|] pooled
  in
  concatenate 3 [|b_1x1; b_5x5; b_3x3dbl; b_pool|]
|
|
|
(** [mix_typ3 nn] builds three parallel branches on top of [nn], each ending
    in a stride-2 [VALID] operation (3x3 convolution, double 3x3
    convolution, 3x3 max-pool), and concatenates them along axis 3.
    The kernels take 288 as their input-channel entry.

    NOTE(review): branch creation order is kept as in the original;
    presumably layer naming depends on it — confirm before reordering. *)
let mix_typ3 nn =
  let b_3x3 = nn |> conv2d_bn ~padding:VALID [|3; 3; 288; 384|] [|2; 2|] in
  let b_3x3dbl =
    let reduced = conv2d_bn [|1; 1; 288; 64|] [|1; 1|] nn in
    let widened = conv2d_bn [|3; 3; 64; 96|] [|1; 1|] reduced in
    conv2d_bn ~padding:VALID [|3; 3; 96; 96|] [|2; 2|] widened
  in
  let b_pool = nn |> max_pool2d ~padding:VALID [|3; 3|] [|2; 2|] in
  concatenate 3 [|b_3x3; b_3x3dbl; b_pool|]
|
|
|
(** [mix_typ4 size nn] builds four parallel branches on top of [nn]
    (1x1, factorised 7x7, double factorised 7x7, average-pool) and
    concatenates them along axis 3. [size] is the intermediate channel
    entry used inside the 7x7 branches; the kernels take 768 as their
    input-channel entry.

    NOTE(review): branch creation order is kept as in the original;
    presumably layer naming depends on it — confirm before reordering. *)
let mix_typ4 size nn =
  (* Apply a left-to-right sequence of stride-1 conv2d_bn layers. *)
  let chain kernels node =
    List.fold_left (fun acc kernel -> conv2d_bn kernel [|1; 1|] acc) node kernels
  in
  let b_1x1 = chain [ [|1; 1; 768; 192|] ] nn in
  let b_7x7 =
    chain
      [ [|1; 1; 768; size|];
        [|1; 7; size; size|];
        [|7; 1; size; 192|] ]
      nn
  in
  let b_7x7dbl =
    chain
      [ [|1; 1; 768; size|];
        [|7; 1; size; size|];
        [|1; 7; size; size|];
        [|7; 1; size; size|];
        [|1; 7; size; 192|] ]
      nn
  in
  let b_pool =
    (* padding = SAME *)
    let pooled = avg_pool2d [|3; 3|] [|1; 1|] nn in
    chain [ [|1; 1; 768; 192|] ] pooled
  in
  concatenate 3 [|b_1x1; b_7x7; b_7x7dbl; b_pool|]
|
|
|
(** [mix_typ8 nn] builds three parallel branches on top of [nn], each ending
    in a stride-2 [VALID] operation (3x3 convolution, 7x7-factorised chain
    followed by a 3x3 convolution, 3x3 max-pool), and concatenates them
    along axis 3. The kernels take 768 as their input-channel entry.

    NOTE(review): branch creation order is kept as in the original;
    presumably layer naming depends on it — confirm before reordering. *)
let mix_typ8 nn =
  let unit_conv kernel node = conv2d_bn kernel [|1; 1|] node in
  let down_conv kernel node = conv2d_bn ~padding:VALID kernel [|2; 2|] node in
  let b_3x3 =
    let reduced = unit_conv [|1; 1; 768; 192|] nn in
    down_conv [|3; 3; 192; 320|] reduced
  in
  let b_7x7x3 =
    let reduced = unit_conv [|1; 1; 768; 192|] nn in
    let horiz = unit_conv [|1; 7; 192; 192|] reduced in
    let vert = unit_conv [|7; 1; 192; 192|] horiz in
    down_conv [|3; 3; 192; 192|] vert
  in
  let b_pool = nn |> max_pool2d ~padding:VALID [|3; 3|] [|2; 2|] in
  concatenate 3 [|b_3x3; b_7x7x3; b_pool|]
|
|
|
(** [mix_typ9 input nn] builds four parallel branches on top of [nn] and
    concatenates them along axis 3. The 3x3 and double-3x3 branches each
    split into a 1x3 and a 3x1 convolution whose outputs are themselves
    concatenated along axis 3. [input] is the input-channel entry of the
    first kernel of every branch.

    NOTE(review): layers (including the inner concatenations) are created
    in the same order as the original; presumably layer naming depends on
    it — confirm before reordering. *)
let mix_typ9 input nn =
  let conv kernel node = conv2d_bn kernel [|1; 1|] node in
  let b_1x1 = conv [|1; 1; input; 320|] nn in
  let b_3x3 =
    let stem = conv [|1; 1; input; 384|] nn in
    let wide = conv [|1; 3; 384; 384|] stem in
    let tall = conv [|3; 1; 384; 384|] stem in
    concatenate 3 [|wide; tall|]
  in
  let b_3x3dbl =
    let reduced = conv [|1; 1; input; 448|] nn in
    let stem = conv [|3; 3; 448; 384|] reduced in
    let wide = conv [|1; 3; 384; 384|] stem in
    let tall = conv [|3; 1; 384; 384|] stem in
    concatenate 3 [|wide; tall|]
  in
  let b_pool =
    let pooled = avg_pool2d [|3; 3|] [|1; 1|] nn in
    conv [|1; 1; input; 192|] pooled
  in
  concatenate 3 [|b_1x1; b_3x3; b_3x3dbl; b_pool|]
|
|
|
(** [make_network img_size] builds the whole graph for an input of shape
    [|img_size; img_size; 3|]: a stack of convolutions and max-pools, the
    sequence of [mix_typ*] blocks, global average pooling and a final
    1000-way linear layer with a softmax activation. Returns the network
    obtained from [get_network]. *)
let make_network img_size =
  (* Initial convolution / pooling stack. *)
  let stem =
    input [|img_size; img_size; 3|]
    |> conv2d_bn ~padding:VALID [|3; 3; 3; 32|] [|2; 2|]
    |> conv2d_bn ~padding:VALID [|3; 3; 32; 32|] [|1; 1|]
    |> conv2d_bn [|3; 3; 32; 64|] [|1; 1|]
    |> max_pool2d ~padding:VALID [|3; 3|] [|2; 2|]
    |> conv2d_bn ~padding:VALID [|1; 1; 64; 80|] [|1; 1|]
    |> conv2d_bn ~padding:VALID [|3; 3; 80; 192|] [|1; 1|]
    |> max_pool2d ~padding:VALID [|3; 3|] [|2; 2|]
  in
  (* Mixed blocks, applied strictly in sequence. *)
  let after_typ1 =
    stem |> mix_typ1 192 32 |> mix_typ1 256 64 |> mix_typ1 288 64
  in
  let after_typ3 = mix_typ3 after_typ1 in
  let after_typ4 =
    after_typ3 |> mix_typ4 128 |> mix_typ4 160 |> mix_typ4 160 |> mix_typ4 192
  in
  let after_typ8 = mix_typ8 after_typ4 in
  let after_typ9 = after_typ8 |> mix_typ9 1280 |> mix_typ9 2048 in
  (* Classification head. *)
  let pooled = global_avg_pool2d after_typ9 in
  let logits = linear 1000 ~act_typ:Activation.(Softmax 1) pooled in
  get_network logits
|
|
|
(** [infer img] runs the network on the image file [img].

    input: name of input image; output: 1x1000 ndarray.

    The image is resized to [img_size] x [img_size] by the external
    [convert] command (presumably ImageMagick) into a temporary PPM file.
    That file is deleted as soon as it has been loaded, and also when
    conversion or loading fails.

    @raise Failure if the [convert] command exits with a non-zero status. *)
let infer img =
  let nn = make_network img_size in
  Graph.load_weights nn weight_file;
  (* Use the standard path helper instead of splitting on '/' by hand. *)
  let prefix = Filename.remove_extension (Filename.basename img) in
  let tmp_img = Filename.temp_file prefix ".ppm" in
  (* Best-effort removal: a missing temp file is not an error here. *)
  let remove_tmp () = try Sys.remove tmp_img with Sys_error _ -> () in
  let load_resized () =
    (* Quote both paths so that spaces or shell metacharacters in file
       names can neither break nor inject into the command line. *)
    let cmd =
      Printf.sprintf "convert -resize %dx%d\\! %s %s" img_size img_size
        (Filename.quote img) (Filename.quote tmp_img)
    in
    let status = Sys.command cmd in
    (* Fail loudly instead of trying to parse an empty or partial file. *)
    if status <> 0 then
      failwith (Printf.sprintf "infer: `%s` exited with status %d" cmd status);
    ImageUtils.(load_ppm tmp_img |> extend_dim |> normalise)
  in
  let img_ppm =
    match load_resized () with
    | tensor -> remove_tmp (); tensor
    | exception e -> remove_tmp (); raise e
  in
  Graph.model nn img_ppm
|
|
|
(** [to_tuples ?top label] delegates to [Imagenet_cls.to_tuples].

    input: 1x1000 ndarray; output: top-N inference result list, each
    element in the form of (class: string, probability: float).
    [top] defaults to 5. *)
let to_tuples ?(top = 5) label = label |> Imagenet_cls.to_tuples ~top
|
|
|
(** [to_json ?top label] delegates to [Imagenet_cls.to_json].

    input: 1x1000 ndarray; output: top-N inference result as a JSON string.
    [top] defaults to 5. *)
let to_json ?(top = 5) label = label |> Imagenet_cls.to_json ~top
|
|
|
(** [test ()] runs [infer] on the bundled "panda.png" example from the zoo
    path and prints the default top-5 result as JSON on standard output. *)
let test () =
  let example = Owl_zoo_path.extend_zoo_path "panda.png" in
  let prediction = infer example in
  let json = to_json prediction in
  Printf.printf "%s\n" json