-
-
Save lpuchallafiore/510a464254c2e0130e510fb137b637b6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* For each image in a dataset directory, predict whether it belongs to the
   ImageNet class "python" or "camel".
   The pre-trained weight file can be found at:
   https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/resnet18.ot
*)
open Base | |
open Torch | |
open Torch_vision | |
(* Column indices, within the model's 1000-way ImageNet logits, of the two
   classes this program compares: python and camel. *)
let python_idx, camel_idx = (62, 354)
(* [process model ~dir] loads the images found in [dir], runs [model] on them
   in evaluation mode, and prints two percentages:
   - "Python": share of images whose python logit is >= their camel logit
     (ties therefore count as python);
   - "Camel": the complement of the above. *)
let process model ~dir =
  (* Load all the images in the directory as one tensor. *)
  let images = Imagenet.load_images ~dir in
  Tensor.print_shape images ~name:dir;
  (* Inference only: [no_grad] stops autograd from recording the forward
     pass, since no gradient is ever taken from these logits. *)
  let logits =
    Tensor.no_grad (fun () -> Layer.forward_ model images ~is_training:false)
  in
  (* Isolate the python and camel columns; dim 1 is indexed by class. *)
  let python_logits = Tensor.narrow logits ~dim:1 ~start:python_idx ~length:1 in
  let camel_logits = Tensor.narrow logits ~dim:1 ~start:camel_idx ~length:1 in
  let python_proba =
    (* Per-image (python >= camel) indicator, cast to float and averaged,
       gives the fraction of images classified as python. *)
    Tensor.(mean (ge1 python_logits camel_logits |> to_type ~type_:(T Float)))
    |> Tensor.to_float0_exn
  in
  Stdio.printf "Python: %.2f%%\n%!" (100. *. python_proba);
  Stdio.printf "Camel : %.2f%%\n%!" (100. *. (1. -. python_proba))
(* Entry point. Usage: <exe> resnet18.ot python_dir camel_dir
   Builds a ResNet-18 with 1000 output classes on the CPU, loads the
   pre-trained weights from the given file, then reports the python/camel
   split for the images in each of the two directories.
   Fails with a usage message when the argument count is wrong. *)
let () =
  let model_file, python_dir, camel_dir =
    match Sys.argv with
    | [| _; model_file; python_dir; camel_dir |] ->
      model_file, python_dir, camel_dir
    | _ ->
      Printf.failwithf "usage: %s resnet18.ot python_dir camel_dir"
        Sys.argv.(0) ()
  in
  let vs = Var_store.create ~name:"rn" ~device:Cpu () in
  let model = Resnet.resnet18 vs ~num_classes:1000 in
  (* Report the validated binding instead of re-indexing [Sys.argv]. *)
  Stdio.printf "Loading weights from %s\n%!" model_file;
  Serialize.load_multi_ ~named_tensors:(Var_store.all_vars vs)
    ~filename:model_file;
  process model ~dir:python_dir;
  process model ~dir:camel_dir
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment