Skip to content

Instantly share code, notes, and snippets.

@kali
Created December 13, 2021 16:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kali/9b34f24e438f8661ef89fdc7ed67fe4b to your computer and use it in GitHub Desktop.
Save kali/9b34f24e438f8661ef89fdc7ed67fe4b to your computer and use it in GitHub Desktop.
use tract_onnx::prelude::*;
/// Installs `env_logger` in test mode.
///
/// The outcome of `try_init` is deliberately discarded, so calling this more than once
/// (for instance from several tests) never fails the caller.
fn init() {
    let mut builder = env_logger::builder();
    builder.is_test(true);
    let _ = builder.try_init();
}
/// Characters the decoder in `main` maps logit indices onto: the space, the lowercase
/// letters `a` to `z`, and the apostrophe, in that order.
///
/// `main` treats the index one past the end of this table (`VOCAB.len()`) as the "blank" class.
const VOCAB: [char; 28] = [
    // Word separator.
    ' ',
    // Lowercase Latin alphabet.
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    // Apostrophe.
    '\'',
];
/// Transcribes `test.wav` with the `stt_en_quartznet15x5.onnx` model and prints the result.
///
/// The model is loaded with a fixed-size `[1, max_length]` signal input and a symbolic
/// length input. The WAV samples are copied into a zero-padded buffer, the symbol is bound
/// to the number of samples actually placed in that buffer, and the model is run. The
/// logits are then decoded greedily: the best class is picked for every frame, consecutive
/// repetitions of a character are collapsed, and blanks are dropped.
///
/// A per-frame trace (`_` for blank, otherwise the character) is printed on one line,
/// followed by the decoded text.
///
/// # Errors
/// Returns any error raised while loading or optimising the model, opening or reading the
/// WAV file, converting the sample count, running the model, or viewing its first output
/// as `f32` values.
///
/// # Panics
/// Panics if the audio is not mono, is not sampled at 16 kHz, or if the model returns no
/// output at all.
fn main() -> TractResult<()> {
    init();

    let symbolic_length = Symbol::new('L');
    let symbolic_signal_length = tensor1(&[TDim::from(symbolic_length)]);
    // Number of samples in the buffer handed to the model; shorter audio is zero-padded.
    let max_length = 1024 * 1024;
    let model = tract_onnx::onnx()
        .model_for_path("stt_en_quartznet15x5.onnx")?
        .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), &[1, max_length]))?
        .with_input_fact(1, symbolic_signal_length.clone().into())?
        .into_optimized()?
        .into_runnable()?;

    let mut reader = hound::WavReader::open("test.wav")?;
    let spec = reader.spec();
    assert_eq!(spec.channels, 1, "Only mono audio is supported");
    assert_eq!(spec.sample_rate, 16_000, "The model only works on 16kHz audio streams");

    // The buffer holds at most `max_length` samples and the copy loop below stops once it
    // is full, so the length given to the model must be clamped the same way. Otherwise a
    // long recording would announce more samples than were actually provided.
    let available = reader.len() as usize;
    if available > max_length {
        eprintln!(
            "warning: audio has {} samples, only the first {} are transcribed",
            available, max_length
        );
    }
    let length = available.min(max_length);

    let mut signal = tract_ndarray::Array::<f32, _>::zeros((1, max_length));
    for (slot, sample) in signal.iter_mut().zip(reader.samples::<i16>()) {
        // Scale the 16-bit PCM sample by `i16::MAX` to get a float sample.
        *slot = f32::from(sample?) / f32::from(i16::MAX);
    }

    let mut state = SimpleState::new(&model)?;
    state.session_state.resolved_symbols =
        SymbolValues::default().with(symbolic_length, length.try_into()?);
    let predictions = state.run(tvec!(signal.into(), symbolic_signal_length))?;

    // We decode the prediction from logits
    let logits = predictions
        .first()
        .expect("the model should produce at least one output")
        .as_slice::<f32>()?;

    // One score per vocabulary entry, plus a final one for the "blank" class.
    let frame_width = VOCAB.len() + 1;
    let mut text = String::new();
    let mut ends_with_blank = false;
    // `chunks_exact` ignores a trailing partial frame instead of decoding it as if it were
    // a complete one.
    for probabilities in logits.chunks_exact(frame_width) {
        // Index of the highest score; on ties the earliest index wins.
        let (best, _) = probabilities
            .iter()
            .enumerate()
            .reduce(|(c_a, v_a), (c_b, v_b)| if v_a >= v_b { (c_a, v_a) } else { (c_b, v_b) })
            .expect("frames yielded by chunks_exact are never empty");
        if best == VOCAB.len() {
            print!("_");
            ends_with_blank = true;
        } else {
            let c = VOCAB[best];
            // A repeated character is only kept when a blank separated the two occurrences.
            if ends_with_blank || !text.ends_with(c) {
                text.push(c);
            }
            ends_with_blank = false;
            print!("{}", c);
        }
    }
    // Terminate the per-frame trace so the final message starts on its own line.
    println!();
    println!("The predicted text from logits is \"{}\"", text);
    Ok(())
}
/// Smoke test: the model loads, optimises and runs when the signal length is a concrete
/// value (128 samples of silence). Only the absence of errors is checked.
#[test]
fn length_is_known() -> TractResult<()> {
    let samples = 128;
    // The length input carries the same concrete value as the signal's second dimension.
    let lengths = tensor1(&[samples as i32]);
    let plan = tract_onnx::onnx()
        .model_for_path("stt_en_quartznet15x5.onnx")?
        .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), &[1, samples]))?
        .with_input_fact(1, lengths.clone().into())?
        .into_optimized()?
        .into_runnable()?;
    let silence = tract_ndarray::Array2::<f32>::zeros([1, samples]).into_tensor();
    plan.run(tvec!(silence, lengths))?;
    Ok(())
}
/// Smoke test: the model loads, optimises and runs when the length input is the symbol `L`,
/// which is bound to 128 just before running on a `[2, 1024]` tensor of zeros. Only the
/// absence of errors is checked.
#[test]
fn length_is_variable() -> TractResult<()> {
    init();
    let symbol = Symbol::new('L');
    // The length input is declared symbolically at model-building time.
    let lengths = tensor1(&[TDim::from(symbol)]);
    let plan = tract_onnx::onnx()
        .model_for_path("stt_en_quartznet15x5.onnx")?
        .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), &[2, 1024]))?
        .with_input_fact(1, lengths.clone().into())?
        .into_optimized()?
        .into_runnable()?;
    let silence = tract_ndarray::Array2::<f32>::zeros([2, 1024]).into_tensor();
    let mut session = SimpleState::new(&plan)?;
    // Give the symbol its concrete value for this particular run.
    session.session_state.resolved_symbols = SymbolValues::default().with(symbol, 128);
    session.run(tvec!(silence, lengths))?;
    Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment