Created February 13, 2019 19:32
Save kali/9357e58956ecef7d3dcb2e88530c61c7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate flate2; | |
extern crate image; | |
extern crate ndarray; | |
extern crate protobuf; | |
extern crate tar; | |
#[allow(unused_imports)] | |
#[macro_use] | |
extern crate tract_core; | |
use std::{fs, io, path}; | |
use ndarray::prelude::*; | |
use tract_core::*; | |
use tract_onnx::pb::TensorProto; | |
use tract_onnx::*; | |
const MODEL: &[u8] = include_bytes!("../squeezenet/model.onnx"); | |
fn main() { | |
let mut model = tract_onnx::model::for_reader(MODEL).unwrap(); | |
model = model.into_optimized().unwrap(); | |
let plan = SimplePlan::new(&model).unwrap(); | |
let labels = load_labels(); | |
let (input, expected) = load_dataset("squeezenet/test_data_set_10"); | |
println!("input: {:?}", input); | |
let computed = plan.run(input).unwrap(); | |
println!("computed: {:?}", computed); | |
println!("expected: {:?}", expected); | |
let label_id = computed[0] | |
.to_array_view::<f32>() | |
.unwrap() | |
.iter() | |
.enumerate() | |
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(0u32.cmp(&1))) | |
.unwrap() | |
.0; | |
let label = &labels[label_id]; | |
println!("computed label(id = {}): {}", label_id, label); | |
let mut img = load_image("dog.png"); | |
println!("input dog: {:?}", img); | |
normalize_images(&mut img.view_mut()); | |
let computed = plan.run(tvec![img.into()]).unwrap(); | |
println!("output dog: {}", computed[0].dump(true).unwrap()); | |
let label_id = computed[0] | |
.to_array_view::<f32>() | |
.unwrap() | |
.iter() | |
.enumerate() | |
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(0u32.cmp(&1))) | |
.unwrap() | |
.0; | |
let label = &labels[label_id]; | |
println!("computed label(id = {}): {}", label_id, label); | |
} | |
pub fn load_dataset(path: &str) -> (TVec<Tensor>, TVec<Tensor>) { | |
( | |
load_half_dataset("input", path), | |
load_half_dataset("output", path), | |
) | |
} | |
pub fn load_half_dataset(prefix: &str, path: &str) -> TVec<Tensor> { | |
let mut vec = tvec!(); | |
let len = fs::read_dir(path) | |
.map_err(|e| format!("accessing {:?}, {:?}", path, e)) | |
.unwrap() | |
.filter(|d| { | |
d.as_ref() | |
.unwrap() | |
.file_name() | |
.to_str() | |
.unwrap() | |
.starts_with(prefix) | |
}) | |
.count(); | |
for i in 0..len { | |
let filename = format!("{}/{}_{}.pb", path, prefix, i); | |
let mut file = fs::File::open(filename) | |
.map_err(|e| format!("accessing {:?}, {:?}", path, e)) | |
.unwrap(); | |
let tensor: TensorProto = ::protobuf::parse_from_reader(&mut file).unwrap(); | |
vec.push(tensor.tractify().unwrap()) | |
} | |
vec | |
} | |
pub fn load_image<P: AsRef<path::Path>>(p: P) -> Array4<f32> { | |
let image = ::image::open(&p).unwrap().to_rgb(); | |
let resized = ::image::imageops::resize(&image, 224, 224, ::image::FilterType::Triangle); | |
Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| { | |
resized[(x as _, y as _)][c] as f32 // 255.0 | |
}) | |
} | |
pub fn normalize_images(images: &mut ArrayViewMut4<f32>) { | |
// https://mxnet.incubator.apache.org/api/python/gluon/data.html#mxnet.gluon.data.vision.transforms.Normalize | |
let means = [0.485, 0.456, 0.406]; | |
let stds = [0.229, 0.224, 0.225]; | |
for mut image in images.outer_iter_mut() { | |
for (ix, mut chan) in image.outer_iter_mut().enumerate() { | |
chan -= means[ix]; | |
chan /= stds[ix]; | |
} | |
} | |
} | |
/// Reads `labels.txt` from the current working directory, one label per line.
///
/// # Panics
/// Panics if the file cannot be opened or a line cannot be read.
pub fn load_labels() -> Vec<String> {
    use std::io::BufRead;
    let file = fs::File::open("labels.txt").unwrap();
    let reader = io::BufReader::new(file);
    let mut labels = Vec::new();
    for line in reader.lines() {
        labels.push(line.unwrap());
    }
    labels
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.