Created
December 13, 2021 16:16
-
-
Save kali/9b34f24e438f8661ef89fdc7ed67fe4b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use tract_onnx::prelude::*; | |
fn init() { | |
let _ = env_logger::builder().is_test(true).try_init(); | |
} | |
// Character set of the QuartzNet CTC head: space, the lowercase letters
// a-z, and the apostrophe (28 classes). The model's output has one extra
// class on top of these — the CTC "blank" — which is why the decoder in
// `main` chunks the logits by `VOCAB.len() + 1`.
const VOCAB: [char; 28] = [
    ' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
    's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\'',
];
fn main() -> TractResult<()> { | |
init(); | |
let symbolic_length = Symbol::new('L'); | |
let symbolic_signal_length = tensor1(&[TDim::from(symbolic_length)]); | |
let max_length = 1024 * 1024; | |
let model = tract_onnx::onnx() | |
.model_for_path("stt_en_quartznet15x5.onnx")? | |
.with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), &[1, max_length]))? | |
.with_input_fact(1, symbolic_signal_length.clone().into())? | |
.into_optimized()? | |
.into_runnable()?; | |
let mut reader = hound::WavReader::open("test.wav")?; | |
assert_eq!(reader.spec().channels, 1, "Only mono audio is supported"); | |
assert_eq!(reader.spec().sample_rate, 16_000, "The model only works on 16kHz audio streams"); | |
let length = reader.len() as usize; | |
let mut signal = tract_ndarray::Array::<f32, _>::zeros((1, max_length)); | |
for (signal, f) in signal.iter_mut().zip(reader.samples::<i16>()) { | |
*signal = f32::from(f?) / f32::from(i16::MAX); | |
} | |
let mut state = SimpleState::new(&model)?; | |
state.session_state.resolved_symbols = | |
SymbolValues::default().with(symbolic_length, length.try_into()?); | |
let predictions = state.run(tvec!(signal.into(), symbolic_signal_length))?; | |
// We decode the prediction from logits | |
let logits = predictions.get(0).unwrap().as_slice::<f32>().unwrap(); | |
let mut text = String::new(); | |
let mut ends_with_blank = false; | |
for probabilities in logits.chunks(VOCAB.len() + 1) { | |
// The last value is the "blank" value | |
let (best, _) = probabilities | |
.iter() | |
.enumerate() | |
.reduce(|(c_a, v_a), (c_b, v_b)| if v_a >= v_b { (c_a, v_a) } else { (c_b, v_b) }) | |
.unwrap(); | |
if best == VOCAB.len() { | |
print!("_"); | |
ends_with_blank = true; | |
} else { | |
let c = VOCAB[best]; | |
if ends_with_blank || !text.ends_with(c) { | |
text.push(c); | |
} | |
ends_with_blank = false; | |
print!("{}", c); | |
} | |
} | |
println!("The predicted text from logits is \"{}\"", text); | |
Ok(()) | |
} | |
#[test]
/// Sanity check: the model builds and runs when the signal length is a
/// plain compile-time constant (no symbolic dimension involved).
fn length_is_known() -> TractResult<()> {
    // Consistency fix: initialize logging like `length_is_variable` does.
    init();
    let length = 128;
    let signal = tract_ndarray::Array2::<f32>::zeros([1, length]).into_tensor();
    // `length` is small and constant, so the narrowing cast cannot truncate.
    let signal_length = tensor1(&[length as i32]);
    let model = tract_onnx::onnx()
        .model_for_path("stt_en_quartznet15x5.onnx")?
        .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), &[1, length]))?
        .with_input_fact(1, signal_length.clone().into())?
        .into_optimized()?
        .into_runnable()?;
    model.run(tvec!(signal, signal_length))?;
    Ok(())
}
#[test]
/// Check that the model runs when the signal length is a symbolic
/// dimension resolved to a concrete value at run time.
fn length_is_variable() -> TractResult<()> {
    init();

    // Declare the signal length as the symbolic dimension 'L'.
    let len_symbol = Symbol::new('L');
    let len_tensor = tensor1(&[TDim::from(len_symbol)]);
    let wave = tract_ndarray::Array2::<f32>::zeros([2, 1024]).into_tensor();

    // Build the model against a fixed [2, 1024] signal buffer.
    let runnable = tract_onnx::onnx()
        .model_for_path("stt_en_quartznet15x5.onnx")?
        .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), &[2, 1024]))?
        .with_input_fact(1, len_tensor.clone().into())?
        .into_optimized()?
        .into_runnable()?;

    // Bind 'L' to a concrete value through the session state, then run.
    let mut state = SimpleState::new(&runnable)?;
    state.session_state.resolved_symbols = SymbolValues::default().with(len_symbol, 128);
    state.run(tvec!(wave, len_tensor))?;
    Ok(())
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.