Created
February 16, 2019 11:07
-
-
Save kali/861965ae39c45497c3a7954115d2c3ab to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate flate2; | |
extern crate image; | |
extern crate ndarray; | |
extern crate protobuf; | |
extern crate tar; | |
#[allow(unused_imports)] | |
#[macro_use] | |
extern crate tract_core; | |
use std::{fs, io, path}; | |
use ndarray::prelude::*; | |
use tract_core::*; | |
use tract_onnx::pb::TensorProto; | |
use tract_onnx::*; | |
const MODEL: &[u8] = include_bytes!("../squeezenet1_1.onnx"); | |
fn main() { | |
let mut model = tract_onnx::model::for_reader(MODEL).unwrap(); | |
model = model.into_optimized().unwrap(); | |
let plan = SimplePlan::new(&model).unwrap(); | |
let labels = load_labels(); | |
for file in fs::read_dir(".").unwrap() { | |
let file = file.unwrap(); | |
if file.path().extension() == Some("png".as_ref()) { | |
let mut img = load_image(&*file.path().to_string_lossy()); | |
normalize_images(&mut img.view_mut()); | |
let computed = plan.run(tvec![img.into()]).unwrap(); | |
let label_id = computed[0] | |
.to_array_view::<f32>() | |
.unwrap() | |
.iter() | |
.enumerate() | |
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(0u32.cmp(&1))) | |
.unwrap() | |
.0; | |
let label = &labels[label_id]; | |
println!("{:?} -> computed label(id = {}): {}", file.path(), label_id, label); | |
} | |
} | |
} | |
pub fn load_dataset(path: &str) -> (TVec<Tensor>, TVec<Tensor>) { | |
( | |
load_half_dataset("input", path), | |
load_half_dataset("output", path), | |
) | |
} | |
pub fn load_half_dataset(prefix: &str, path: &str) -> TVec<Tensor> { | |
let mut vec = tvec!(); | |
let len = fs::read_dir(path) | |
.map_err(|e| format!("accessing {:?}, {:?}", path, e)) | |
.unwrap() | |
.filter(|d| { | |
d.as_ref() | |
.unwrap() | |
.file_name() | |
.to_str() | |
.unwrap() | |
.starts_with(prefix) | |
}) | |
.count(); | |
for i in 0..len { | |
let filename = format!("{}/{}_{}.pb", path, prefix, i); | |
let mut file = fs::File::open(filename) | |
.map_err(|e| format!("accessing {:?}, {:?}", path, e)) | |
.unwrap(); | |
let tensor: TensorProto = ::protobuf::parse_from_reader(&mut file).unwrap(); | |
vec.push(tensor.tractify().unwrap()) | |
} | |
vec | |
} | |
pub fn load_image<P: AsRef<path::Path>>(p: P) -> Array4<f32> { | |
let image = ::image::open(&p).unwrap().to_rgb(); | |
Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| { | |
image[(x as u32, y as u32)][c] as f32 | |
}) | |
} | |
pub fn normalize_images(images: &mut ArrayViewMut4<f32>) { | |
// https://mxnet.incubator.apache.org/api/python/gluon/data.html#mxnet.gluon.data.vision.transforms.Normalize | |
let means = [0.485, 0.456, 0.406]; | |
let stds = [0.229, 0.224, 0.225]; | |
for mut image in images.outer_iter_mut() { | |
for (ix, mut chan) in image.outer_iter_mut().enumerate() { | |
chan /= 255.0; | |
chan -= means[ix]; | |
chan /= stds[ix]; | |
} | |
} | |
} | |
/// Reads `labels.txt` from the current working directory and returns its
/// lines in file order, one `String` per line.
///
/// Panics if the file cannot be opened or if reading any line fails.
pub fn load_labels() -> Vec<String> {
    use std::io::BufRead;
    let file = fs::File::open("labels.txt").unwrap();
    let reader = io::BufReader::new(file);
    let mut labels = Vec::new();
    for line in reader.lines() {
        labels.push(line.unwrap());
    }
    labels
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment