Created
February 13, 2019 17:35
-
-
Save Geal/f486b4b0b8339fda31db3b7174335734 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//extern crate conform; | |
extern crate flate2; | |
extern crate image; | |
extern crate ndarray; | |
extern crate tar; | |
extern crate protobuf; | |
#[allow(unused_imports)] | |
#[macro_use] | |
extern crate tract_core; | |
use std::{fs, io, path}; | |
use tract_core::*; | |
use tract_onnx::pb::TensorProto; | |
use tract_onnx::*; | |
const MODEL: &[u8] = include_bytes!("../squeezenet/model.onnx"); | |
fn main() { | |
let mut model = tract_onnx::model::for_reader(MODEL).unwrap(); | |
model = model.into_optimized().unwrap(); | |
let plan = SimplePlan::new(&model).unwrap(); | |
let labels = load_labels(); | |
let (input, expected) = load_dataset("squeezenet/test_data_set_10"); | |
println!("input: {:?}", input); | |
let computed = plan.run(input).unwrap(); | |
println!("computed: {:?}", computed); | |
println!("expected: {:?}", expected); | |
let label_id = computed[0] | |
.to_array_view::<f32>() | |
.unwrap() | |
.iter() | |
.enumerate() | |
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(0u32.cmp(&1))) | |
.unwrap() | |
.0; | |
let label = &labels[label_id]; | |
println!("computed label(id = {}): {}", label_id, label); | |
let img = load_image("dog.png"); | |
println!("input dog: {:?}", img); | |
let computed = plan.run(tvec![img]).unwrap(); | |
println!("output dog: {:?}", computed); | |
let label_id = computed[0] | |
.to_array_view::<f32>() | |
.unwrap() | |
.iter() | |
.enumerate() | |
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(0u32.cmp(&1))) | |
.unwrap() | |
.0; | |
let label = &labels[label_id]; | |
println!("computed label(id = {}): {}", label_id, label); | |
} | |
pub fn load_dataset(path: &str) -> (TVec<Tensor>, TVec<Tensor>) { | |
( | |
load_half_dataset("input", path), | |
load_half_dataset("output", path), | |
) | |
} | |
pub fn load_half_dataset(prefix: &str, path: &str) -> TVec<Tensor> { | |
let mut vec = tvec!(); | |
let len = fs::read_dir(path) | |
.map_err(|e| format!("accessing {:?}, {:?}", path, e)) | |
.unwrap() | |
.filter(|d| { | |
d.as_ref() | |
.unwrap() | |
.file_name() | |
.to_str() | |
.unwrap() | |
.starts_with(prefix) | |
}) | |
.count(); | |
for i in 0..len { | |
let filename = format!("{}/{}_{}.pb", path, prefix, i); | |
let mut file = fs::File::open(filename) | |
.map_err(|e| format!("accessing {:?}, {:?}", path, e)) | |
.unwrap(); | |
let tensor: TensorProto = ::protobuf::parse_from_reader(&mut file).unwrap(); | |
vec.push(tensor.tractify().unwrap()) | |
} | |
vec | |
} | |
pub fn load_image<P: AsRef<path::Path>>(p: P) -> ::tract_core::Tensor { | |
let image = ::image::open(&p).unwrap().to_rgb(); | |
let resized = ::image::imageops::resize(&image, 224, 224, ::image::FilterType::Triangle); | |
let image: ::tract_core::Tensor = | |
//::ndarray::Array4::from_shape_fn((1, 224, 224, 3), |(_, y, x, c)| { | |
::ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| { | |
resized[(x as _, y as _)][c] as f32 / 255.0 | |
}) | |
.into_dyn() | |
.into(); | |
image | |
} | |
/// Reads `labels.txt` from the current working directory, returning one entry
/// per line with the line terminator stripped.
///
/// # Panics
///
/// Panics if `labels.txt` cannot be opened or if reading any line fails.
pub fn load_labels() -> Vec<String> {
    use std::io::BufRead;

    let file = fs::File::open("labels.txt").unwrap();
    let mut labels = Vec::new();
    for line in io::BufReader::new(file).lines() {
        labels.push(line.unwrap());
    }
    labels
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment