Skip to content

Instantly share code, notes, and snippets.

@Geal
Created February 13, 2019 17:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Geal/f486b4b0b8339fda31db3b7174335734 to your computer and use it in GitHub Desktop.
Save Geal/f486b4b0b8339fda31db3b7174335734 to your computer and use it in GitHub Desktop.
//extern crate conform;
extern crate flate2;
extern crate image;
extern crate ndarray;
extern crate tar;
extern crate protobuf;
#[allow(unused_imports)]
#[macro_use]
extern crate tract_core;
use std::{fs, io, path};
use tract_core::*;
use tract_onnx::pb::TensorProto;
use tract_onnx::*;
const MODEL: &[u8] = include_bytes!("../squeezenet/model.onnx");
/// Returns the position of the largest score, or `None` when `scores` is empty.
///
/// Scores are compared with `partial_cmp`; when a pair cannot be ordered
/// (one side is NaN) the first argument of the comparison is treated as the
/// smaller one.
fn index_of_max<'a, I>(scores: I) -> Option<usize>
where
    I: IntoIterator<Item = &'a f32>,
{
    scores
        .into_iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(::std::cmp::Ordering::Less))
        .map(|(index, _)| index)
}

/// Prints the index of the best score together with its label.
///
/// An index that has no entry in `labels` is reported as `<unknown label>`
/// instead of aborting the program.
///
/// # Panics
///
/// Panics when `scores` yields no element.
fn print_top_label<'a, I>(labels: &[String], scores: I)
where
    I: IntoIterator<Item = &'a f32>,
{
    let label_id = index_of_max(scores).expect("model produced no output scores");
    let label = labels
        .get(label_id)
        .map(String::as_str)
        .unwrap_or("<unknown label>");
    println!("computed label(id = {}): {}", label_id, label);
}

/// Runs the embedded ONNX model twice: once on the bundled reference data set
/// and once on `dog.png`, printing the tensors and the best label each time.
///
/// # Panics
///
/// Panics when the model cannot be loaded, optimized or run, when a data
/// file is missing or malformed, or when the model returns no output.
fn main() {
    let model = tract_onnx::model::for_reader(MODEL)
        .unwrap()
        .into_optimized()
        .unwrap();
    let plan = SimplePlan::new(&model).unwrap();
    let labels = load_labels();

    // First pass: reference input with its expected output, for comparison.
    let (input, expected) = load_dataset("squeezenet/test_data_set_10");
    println!("input: {:?}", input);
    let computed = plan.run(input).unwrap();
    println!("computed: {:?}", computed);
    println!("expected: {:?}", expected);
    print_top_label(&labels, computed[0].to_array_view::<f32>().unwrap().iter());

    // Second pass: an actual picture loaded from disk.
    let img = load_image("dog.png");
    println!("input dog: {:?}", img);
    let computed = plan.run(tvec![img]).unwrap();
    println!("output dog: {:?}", computed);
    print_top_label(&labels, computed[0].to_array_view::<f32>().unwrap().iter());
}
/// Loads both halves of a test data set stored in the directory `path`.
///
/// Returns `(inputs, outputs)`: the tensors read with the `"input"` prefix
/// followed by the tensors read with the `"output"` prefix.
pub fn load_dataset(path: &str) -> (TVec<Tensor>, TVec<Tensor>) {
    let inputs = load_half_dataset("input", path);
    let outputs = load_half_dataset("output", path);
    (inputs, outputs)
}
/// Loads the tensors stored as `<prefix>_<i>.pb` files in the directory `path`.
///
/// The number of tensors is the number of directory entries whose name starts
/// with `prefix`; the files are then read in index order, starting at
/// `<prefix>_0.pb`, so the entries are expected to be numbered contiguously.
///
/// # Panics
///
/// Panics when the directory cannot be listed, when one of the expected
/// files cannot be opened (the message names that file), or when a file does
/// not decode into a tensor.
pub fn load_half_dataset(prefix: &str, path: &str) -> TVec<Tensor> {
    let dir = ::std::path::Path::new(path);
    let len = fs::read_dir(dir)
        .unwrap_or_else(|e| panic!("accessing {:?}, {:?}", path, e))
        .filter(|entry| {
            entry
                .as_ref()
                .unwrap_or_else(|e| panic!("accessing {:?}, {:?}", path, e))
                .file_name()
                // Lossy conversion: an unrelated non-UTF-8 name in the
                // directory must not abort the count.
                .to_string_lossy()
                .starts_with(prefix)
        })
        .count();

    let mut tensors = tvec!();
    for i in 0..len {
        let filename = dir.join(format!("{}_{}.pb", prefix, i));
        // Report the file that failed to open, not just its directory.
        let mut file = fs::File::open(&filename)
            .unwrap_or_else(|e| panic!("accessing {:?}, {:?}", filename, e));
        let tensor: TensorProto = ::protobuf::parse_from_reader(&mut file).unwrap();
        tensors.push(tensor.tractify().unwrap())
    }
    tensors
}
/// Reads the image at `p` and turns it into a `(1, 3, 224, 224)` float tensor.
///
/// The picture is converted to RGB, resized to 224x224 with the triangle
/// filter, and every 8-bit sample is divided by 255.
///
/// # Panics
///
/// Panics when the image cannot be opened or decoded.
pub fn load_image<P: AsRef<path::Path>>(p: P) -> ::tract_core::Tensor {
    const SIDE: usize = 224;

    let rgb = ::image::open(&p).unwrap().to_rgb();
    let scaled = ::image::imageops::resize(
        &rgb,
        SIDE as u32,
        SIDE as u32,
        ::image::FilterType::Triangle,
    );

    // Channel-first index order: (batch, channel, row, column).
    let pixels = ::ndarray::Array4::from_shape_fn((1, 3, SIDE, SIDE), |(_, channel, row, col)| {
        let sample = scaled[(col as u32, row as u32)][channel];
        sample as f32 / 255.0
    });

    pixels.into_dyn().into()
}
/// Reads `labels.txt` from the current directory, one label per line.
///
/// # Panics
///
/// Panics when the file cannot be opened or a line cannot be read.
pub fn load_labels() -> Vec<String> {
    use std::io::BufRead;

    let file = fs::File::open("labels.txt").unwrap();
    let mut labels = Vec::new();
    for line in io::BufReader::new(file).lines() {
        labels.push(line.unwrap());
    }
    labels
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment