Skip to content

Instantly share code, notes, and snippets.

@LoipesMas
Last active January 23, 2023 11:11
Show Gist options
  • Save LoipesMas/2d342b8087dbae4af31d8af2752e84de to your computer and use it in GitHub Desktop.
Save LoipesMas/2d342b8087dbae4af31d8af2752e84de to your computer and use it in GitHub Desktop.
Inference with a SqueezeNet ONNX model using ort (the Rust bindings for ONNX Runtime)
use std::{path::Path, sync::Arc};
use image::io::Reader as ImageReader;
use ndarray::{s, ArrayBase, Axis, Dim, IxDyn, OwnedRepr};
use nshare::ToNdarray3;
use ort::{
tensor::{DynOrtTensor, FromArray, InputTensor, OrtOwnedTensor},
Environment, OrtResult, SessionBuilder,
};
fn main() {
infere("path/to/image.png").unwrap();
}
/// Loads the image at `path` and crops a fixed 224×224 window whose top-left corner is
/// at pixel (100, 100). The crop is converted to 32-bit float RGB and then to a rank-3
/// `f32` array.
///
/// The array is produced by `nshare`'s `ToNdarray3`. For an RGB buffer this is presumably
/// channel-first `(3, height, width)`; TODO confirm against the nshare docs. No mean/std
/// normalisation is applied here.
///
/// NOTE(review): the `image` crate's `crop_imm` clamps the window to the image bounds.
/// An image smaller than 324×324 therefore yields a smaller array rather than an error.
///
/// # Panics
///
/// Panics, naming the offending path, if the file cannot be opened or decoded.
fn open_image<P: AsRef<Path>>(path: P) -> ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>> {
    // Crop window fed to the model, named so the magic numbers are explained in one place.
    const CROP_X: u32 = 100;
    const CROP_Y: u32 = 100;
    const CROP_SIZE: u32 = 224;

    let path = path.as_ref();
    let img = ImageReader::open(path)
        .unwrap_or_else(|e| panic!("failed to open image {}: {}", path.display(), e))
        .decode()
        .unwrap_or_else(|e| panic!("failed to decode image {}: {}", path.display(), e));
    img.crop_imm(CROP_X, CROP_Y, CROP_SIZE, CROP_SIZE)
        .into_rgb32f()
        .into_ndarray3()
}
/// Runs the SqueezeNet ONNX model on a crop of the image at `file_path` and prints
/// diagnostics with `dbg!`.
///
/// The model is loaded from `./squeezenet1.0-13-qdq.onnx`, relative to the current
/// working directory. The image is prepared by [`open_image`] and given a leading batch
/// axis of length 1 before being fed to the session.
///
/// Printed to stderr:
/// - the input shape;
/// - the channel values at two probe pixels;
/// - the `(flat index, score)` pair with the highest value in the first model output;
/// - the slice of that output at index 322 of axis 1.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the ort environment or session cannot be built;
/// - the model file cannot be loaded;
/// - running inference fails;
/// - the first output cannot be extracted as an `f32` tensor.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the image cannot be opened or decoded (see [`open_image`]);
/// - the model returns no outputs;
/// - one of the hard-coded debug slice indices is out of bounds for the actual tensor.
pub fn infere<P: AsRef<Path>>(file_path: P) -> OrtResult<()> {
    const MODEL_PATH: &str = "./squeezenet1.0-13-qdq.onnx";

    let environment = Arc::new(Environment::builder().build()?);
    let session = SessionBuilder::new(&environment)?.with_model_from_file(MODEL_PATH)?;

    // Prepend a batch axis of length 1 to the rank-3 image array.
    let input = open_image(file_path).insert_axis(Axis(0));
    dbg!(&input.shape());
    // Spot-check the per-channel values at two fixed pixels.
    dbg!(input.slice(s![0, .., 100, 100]));
    dbg!(input.slice(s![0, .., 180, 50]));

    // This function already returns OrtResult, so propagate inference errors instead of panicking.
    let outputs: Vec<DynOrtTensor<IxDyn>> =
        session.run([InputTensor::from_array(input.into_dyn())])?;

    // Keep the owned tensor in its own binding: the view below borrows from it.
    let tensor: OrtOwnedTensor<'_, f32, IxDyn> = outputs[0].try_extract()?;
    let scores = tensor.view();

    // `total_cmp` is a total order on f32, so a NaN score no longer panics the way
    // `partial_cmp(..).unwrap()` did.
    let max_score = scores
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1));
    dbg!(max_score);
    // NOTE(review): assumes a rank-4 output (presumably (1, classes, 1, 1)); 322 is a
    // hard-coded class index kept for inspection. Confirm both against the model.
    dbg!(scores.slice(s![0, 322, .., ..]));
    Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment