Skip to content

Instantly share code, notes, and snippets.

@kali
Created February 13, 2019 19:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kali/9357e58956ecef7d3dcb2e88530c61c7 to your computer and use it in GitHub Desktop.
Save kali/9357e58956ecef7d3dcb2e88530c61c7 to your computer and use it in GitHub Desktop.
extern crate flate2;
extern crate image;
extern crate ndarray;
extern crate protobuf;
extern crate tar;
#[allow(unused_imports)]
#[macro_use]
extern crate tract_core;
use std::{fs, io, path};
use ndarray::prelude::*;
use tract_core::*;
use tract_onnx::pb::TensorProto;
use tract_onnx::*;
/// Serialized ONNX model bytes, embedded into the binary at compile time.
/// `include_bytes!` resolves the path relative to this source file.
const MODEL: &[u8] = include_bytes!("../squeezenet/model.onnx");
fn main() {
let mut model = tract_onnx::model::for_reader(MODEL).unwrap();
model = model.into_optimized().unwrap();
let plan = SimplePlan::new(&model).unwrap();
let labels = load_labels();
let (input, expected) = load_dataset("squeezenet/test_data_set_10");
println!("input: {:?}", input);
let computed = plan.run(input).unwrap();
println!("computed: {:?}", computed);
println!("expected: {:?}", expected);
let label_id = computed[0]
.to_array_view::<f32>()
.unwrap()
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(0u32.cmp(&1)))
.unwrap()
.0;
let label = &labels[label_id];
println!("computed label(id = {}): {}", label_id, label);
let mut img = load_image("dog.png");
println!("input dog: {:?}", img);
normalize_images(&mut img.view_mut());
let computed = plan.run(tvec![img.into()]).unwrap();
println!("output dog: {}", computed[0].dump(true).unwrap());
let label_id = computed[0]
.to_array_view::<f32>()
.unwrap()
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(0u32.cmp(&1)))
.unwrap()
.0;
let label = &labels[label_id];
println!("computed label(id = {}): {}", label_id, label);
}
/// Loads one ONNX test case from the directory `path`.
///
/// Returns `(inputs, outputs)`: the tensors read from the `input_*.pb` files,
/// then the reference tensors read from the `output_*.pb` files.
pub fn load_dataset(path: &str) -> (TVec<Tensor>, TVec<Tensor>) {
    let inputs = load_half_dataset("input", path);
    let outputs = load_half_dataset("output", path);
    (inputs, outputs)
}
/// Loads one half (inputs or outputs) of an ONNX test-data directory.
///
/// The tensor count is the number of entries in `path` whose file name starts
/// with `prefix`. Files are then read as `{path}/{prefix}_0.pb`,
/// `{path}/{prefix}_1.pb`, ... and converted to tract tensors in that order.
///
/// # Panics
/// Panics, naming the offending directory or file, if the directory cannot be
/// listed, a numbered file is missing or unreadable, or a file does not hold a
/// `TensorProto` that converts to a tract `Tensor`.
pub fn load_half_dataset(prefix: &str, path: &str) -> TVec<Tensor> {
    let len = fs::read_dir(path)
        .unwrap_or_else(|e| panic!("accessing {:?}, {:?}", path, e))
        .map(|entry| entry.unwrap_or_else(|e| panic!("listing {:?}, {:?}", path, e)))
        // A non-UTF-8 name can never match `prefix`; skip it instead of panicking.
        .filter(|entry| {
            entry
                .file_name()
                .to_str()
                .map_or(false, |name| name.starts_with(prefix))
        })
        .count();
    let mut vec = tvec!();
    for i in 0..len {
        let filename = format!("{}/{}_{}.pb", path, prefix, i);
        // Report the file that failed, not just the directory it lives in.
        let mut file = fs::File::open(&filename)
            .unwrap_or_else(|e| panic!("accessing {:?}, {:?}", filename, e));
        let tensor: TensorProto = ::protobuf::parse_from_reader(&mut file)
            .unwrap_or_else(|e| panic!("parsing {:?}, {:?}", filename, e));
        vec.push(
            tensor
                .tractify()
                .unwrap_or_else(|e| panic!("converting {:?}, {:?}", filename, e)),
        )
    }
    vec
}
/// Loads the image at `p`, converts it to RGB and resizes it to 224x224
/// (triangle filter, aspect ratio not preserved).
///
/// Returns a `(1, 3, 224, 224)` array laid out as (batch, channel, y, x). Each
/// 8-bit channel value is scaled to `[0, 1]`.
///
/// # Panics
/// Panics if the file cannot be opened or decoded.
pub fn load_image<P: AsRef<path::Path>>(p: P) -> Array4<f32> {
    let image = ::image::open(&p).unwrap().to_rgb();
    let resized = ::image::imageops::resize(&image, 224, 224, ::image::FilterType::Triangle);
    Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
        // Scale to [0, 1], the range that `normalize_images`' mean/std constants
        // are expressed in. Note: in Rust `//` starts a comment, not a floor
        // division, so the scaling must be written with `/`.
        resized[(x as _, y as _)][c] as f32 / 255.0
    })
}
/// Normalizes a batch of images in place, per channel: `x = (x - mean) / std`.
///
/// `images` is iterated as batch -> channel -> plane, and channel `i` uses the
/// `i`-th mean/std pair. More than 3 channels panics on the out-of-bounds
/// lookup. The constants are the widely used ImageNet statistics, which assume
/// pixel values already scaled to `[0, 1]` (NOTE(review): confirm callers scale).
/// See https://mxnet.incubator.apache.org/api/python/gluon/data.html#mxnet.gluon.data.vision.transforms.Normalize
pub fn normalize_images(images: &mut ArrayViewMut4<f32>) {
    let mean = [0.485f32, 0.456, 0.406];
    let std_dev = [0.229f32, 0.224, 0.225];
    images.outer_iter_mut().for_each(|mut image| {
        for (channel, mut plane) in image.outer_iter_mut().enumerate() {
            // Look both constants up before touching the data.
            let (m, s) = (mean[channel], std_dev[channel]);
            // Subtract first, then divide: same op order as two separate passes.
            plane.mapv_inplace(|v| (v - m) / s);
        }
    });
}
/// Reads `labels.txt` (relative to the current working directory) and returns
/// its lines, one label per line, in file order.
///
/// # Panics
/// Panics if the file cannot be opened or a line cannot be read.
pub fn load_labels() -> Vec<String> {
    use std::io::BufRead;
    let file = fs::File::open("labels.txt").unwrap();
    let reader = io::BufReader::new(file);
    let mut labels = Vec::new();
    for line in reader.lines() {
        labels.push(line.unwrap());
    }
    labels
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment