Skip to content

Instantly share code, notes, and snippets.

@lpuchallafiore
Last active September 6, 2019 13:57
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save lpuchallafiore/510a464254c2e0130e510fb137b637b6 to your computer and use it in GitHub Desktop.
Save lpuchallafiore/510a464254c2e0130e510fb137b637b6 to your computer and use it in GitHub Desktop.
(* For each image in a dataset directory, predict whether it looks more like the
ImageNet class python or camel, and print the proportion of images assigned to each.
The pre-trained weight file can be found at:
https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/resnet18.ot
*)
open Base
open Torch
open Torch_vision
(* Output-column indices of the two classes compared by [process].
   Presumably ImageNet-1k class 62 ("rock python") and 354 ("Arabian camel");
   verify against the label list the weights were trained with. *)
let python_idx : int = 62
let camel_idx : int = 354
(* [process model ~dir] loads every image in [dir] as one batch, classifies
   the batch with [model], and prints two percentages:
   - the share of images whose python logit is >= their camel logit (ties
     count as python);
   - its complement, labelled camel.
   Only these two classes are compared, so an image showing neither animal is
   still counted as one or the other. *)
let process model ~dir =
  (* Load all the images in the directory into a single tensor. *)
  let images = Imagenet.load_images ~dir in
  Tensor.print_shape images ~name:dir;
  (* Inference only: run the forward pass with gradient tracking disabled so
     no autograd graph is recorded for the whole batch ([~is_training:false]
     alone does not do this; the Var_store weights may require gradients). *)
  let logits =
    Tensor.no_grad (fun () -> Layer.forward_ model images ~is_training:false)
  in
  (* Isolate the logit column for each class (one column along dim 1). *)
  let python_logits = Tensor.narrow logits ~dim:1 ~start:python_idx ~length:1 in
  let camel_logits = Tensor.narrow logits ~dim:1 ~start:camel_idx ~length:1 in
  let python_proba =
    (* Per image: 1.0 when python >= camel, else 0.0. The mean of these is the
       fraction of images assigned to python. *)
    Tensor.(mean (ge1 python_logits camel_logits |> to_type ~type_:(T Float)))
    |> Tensor.to_float0_exn
  in
  Stdio.printf "Python: %.2f%%\n%!" (100. *. python_proba);
  Stdio.printf "Camel : %.2f%%\n%!" (100. *. (1. -. python_proba))
(* Entry point. Usage: <exe> resnet18.ot python_dir camel_dir
   It does the following:
   - builds a 1000-class ResNet-18 in a CPU variable store;
   - loads the weights from the first argument;
   - prints the python/camel split for each of the two directories.
   Any other argument count raises [Failure] carrying the usage message. *)
let () =
  let model_file, python_dir, camel_dir =
    match Sys.argv with
    | [| _; model_file; python_dir; camel_dir |] -> model_file, python_dir, camel_dir
    | _ -> Printf.failwithf "usage: %s resnet18.ot python_dir camel_dir" Sys.argv.(0) ()
  in
  let vs = Var_store.create ~name:"rn" ~device:Cpu () in
  let model = Resnet.resnet18 vs ~num_classes:1000 in
  (* Report the same path that is actually loaded below, rather than
     re-indexing argv separately. *)
  Stdio.printf "Loading weights from %s\n%!" model_file;
  (* Fill every variable registered in [vs] from the weight file; presumably
     matched by tensor name. Confirm against the .ot file contents. *)
  Serialize.load_multi_ ~named_tensors:(Var_store.all_vars vs) ~filename:model_file;
  process model ~dir:python_dir;
  process model ~dir:camel_dir
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment