Created
June 14, 2022 21:25
-
-
Save etemesi254/e2cc65afbdf49b9b402506a9c9ff6802 to your computer and use it in GitHub Desktop.
MCU decoding without unsafe
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Implements routines to decode a MCU | |
//! | |
//! # Side notes | |
//! Yes, I pull in some dubious tricks, like really dubious here, they're not hard to come up | |
//! but I know they're hard to understand(e.g how I don't allocate space for Cb and Cr | |
//! channels if output colorspace is grayscale) but bear with me, it's the search for fast software | |
//! that got me here. | |
//! | |
//! # Multithreading | |
//! | |
//!This isn't exposed so I can dump all the info here | |
//! | |
//! To make multithreading work, we want to break dependency chains but in cool ways. | |
//! i.e we want to find out where we can forward one section as another one does something. | |
//! | |
//! # The algorithm | |
//! Simply do it per MCU width taking into account sub-sampling ratios | |
//! | |
//! 1. Decode an MCU width taking into account how many image channels we have(either Y only or Y,Cb and Cr) | |
//! | |
//! 2. After successfully decoding, copy pixels decoded and spawn a thread to handle post processing(IDCT, | |
//! upsampling and color conversion) | |
//! | |
//! 3. After successfully decoding all pixels, join threads. | |
//! | |
//! 4. Call it a day, | |
//! | |
//!But as easy as this sounds in theory, in practice, it sucks... | |
//! | |
//! We essentially have to consider that down-sampled images have weird MCU arrangement and for such cases | |
//! ! choose the path of decoding 2 whole MCU heights for horizontal/vertical upsampling and | |
//! 4 whole MCU heights for horizontal and vertical upsampling, which when expressed in code doesn't look nice. | |
//! | |
//! There is also the overhead of synchronization which makes some things annoying. | |
//! | |
//! Also there is the overhead of `cloning` and allocating intermediate memory to ensure multithreading is safe. | |
//! This may make this library almost 3X slower if someone chooses to disable `threadpool` (please don't) feature because | |
//! we are optimized for the multithreading path. | |
//! | |
//! # Scoped ThreadPools | |
//! Things you don't want to do in the fast path. **Lock da mutex** | |
//! Things you don't want to have in your code. **Mutex** | |
//! | |
//! Multithreading is not everyone's cake because synchronization is like battling with the devil | |
//! The default way is a mutex for when threads may write to the same memory location. But in our case we | |
//! don't write to the same, location, so why pay for something not used. | |
//! | |
//! In C/C++ land we can just pass mutable chunks to different threads but in Rust don't you know about | |
//! the borrow checker?... | |
//! | |
//! To send different mutable chunks to threads, we use scoped threads which guarantee that the thread | |
//! won't outlive the data and finally let it compile. | |
//! This allows us to not use locks during decoding avoiding that overhead. and allowing more cleaner | |
//! faster code in post processing.. | |
use std::cmp::min; | |
use std::io::Cursor; | |
use std::sync::Arc; | |
use crate::bitstream::BitStream; | |
use crate::components::{ComponentID, SubSampRatios}; | |
use crate::errors::DecodeErrors; | |
use crate::marker::Marker; | |
use crate::worker::post_process; | |
use crate::Decoder; | |
/// The size of a DC block for a MCU. | |
pub const DCT_BLOCK: usize = 64; | |
impl Decoder | |
{ | |
/// Check for existence of DC and AC Huffman Tables | |
fn check_tables(&self) -> Result<(), DecodeErrors> | |
{ | |
// check that dc and AC tables exist outside the hot path | |
for i in 0..self.input_colorspace.num_components() | |
{ | |
let _ = &self | |
.dc_huffman_tables | |
.get(self.components[i].dc_huff_table) | |
.as_ref() | |
.ok_or_else(|| { | |
DecodeErrors::HuffmanDecode(format!( | |
"No Huffman DC table for component {:?} ", | |
self.components[i].component_id | |
)) | |
})? | |
.as_ref() | |
.ok_or_else(|| { | |
DecodeErrors::HuffmanDecode(format!( | |
"No DC table for component {:?}", | |
self.components[i].component_id | |
)) | |
})?; | |
let _ = &self | |
.ac_huffman_tables | |
.get(self.components[i].ac_huff_table) | |
.as_ref() | |
.ok_or_else(|| { | |
DecodeErrors::HuffmanDecode(format!( | |
"No Huffman AC table for component {:?} ", | |
self.components[i].component_id | |
)) | |
})? | |
.as_ref() | |
.ok_or_else(|| { | |
DecodeErrors::HuffmanDecode(format!( | |
"No AC table for component {:?}", | |
self.components[i].component_id | |
)) | |
})?; | |
} | |
Ok(()) | |
} | |
/// Decode MCUs and carry out post processing. | |
/// | |
/// This is the main decoder loop for the library, the hot path. | |
/// | |
/// Because of this, we pull in some very crazy optimization tricks hence readability is a pinch | |
/// here. | |
#[allow(clippy::similar_names)] | |
#[inline(never)] | |
#[rustfmt::skip] | |
pub(crate) fn decode_mcu_ycbcr_baseline( | |
&mut self, reader: &mut Cursor<Vec<u8>>, | |
) -> Result<Vec<u8>, DecodeErrors> | |
{ | |
self.check_component_dimensions()?; | |
let mut scoped_pools = scoped_threadpool::Pool::new(self.num_threads.unwrap_or( num_cpus::get()) as u32); | |
info!("Created {} worker threads", scoped_pools.thread_count()); | |
let (mcu_width, mcu_height); | |
let mut bias = 1; | |
if self.interleaved | |
{ | |
// set upsampling functions | |
self.set_upsampling()?; | |
if self.sub_sample_ratio == SubSampRatios::H | |
{ | |
// horizontal sub-sampling. | |
// Values for horizontal samples end halfway the image and do not complete an MCU width. | |
// To make it complete we multiply width by 2 and divide mcu_height by 2 | |
mcu_width = self.mcu_x * 2; | |
mcu_height = self.mcu_y / 2; | |
} else if self.sub_sample_ratio == SubSampRatios::HV | |
{ | |
mcu_width = self.mcu_x; | |
mcu_height = self.mcu_y / 2; | |
bias = 2; | |
// V; | |
} else { | |
mcu_width = self.mcu_x; | |
mcu_height = self.mcu_y; | |
} | |
} else { | |
// For non-interleaved images( (1*1) subsampling) | |
// number of MCU's are the widths (+7 to account for paddings) divided bu 8. | |
mcu_width = ((self.info.width + 7) / 8) as usize; | |
mcu_height = ((self.info.height + 7) / 8) as usize; | |
} | |
let mut stream = BitStream::new(); | |
// Size of our output image(width*height) | |
let capacity = usize::from(self.info.width + 7) * usize::from(self.info.height + 7); | |
let component_capacity = mcu_width * DCT_BLOCK; | |
// Create an Arc of components to prevent cloning on every MCU width | |
let global_component = Arc::new(self.components.clone()); | |
// Storage for decoded pixels | |
let mut global_channel = vec![0; capacity * self.output_colorspace.num_components()]; | |
// things needed for post processing that we can remove out of the loop | |
let input = self.input_colorspace; | |
let output = self.output_colorspace; | |
let idct_func = self.idct_func; | |
let color_convert_16 = self.color_convert_16; | |
let width = usize::from(self.width()); | |
let h_max = self.h_max; | |
let v_max = self.v_max; | |
// Halfway width size, used for vertical sub-sampling to write |Y2| in the right position. | |
let width_stride = (component_capacity * self.components[0].vertical_sample * self.components[0].horizontal_sample * bias) >> 1; | |
let hv_width_stride = width_stride >> 1; | |
// check dc and AC tables | |
self.check_tables()?; | |
let is_hv = self.sub_sample_ratio == SubSampRatios::HV; | |
// Split output into different blocks each containing enough space for an MCU width | |
let mut chunks = | |
global_channel.chunks_exact_mut(width * output.num_components() * 8 * h_max * v_max); | |
let mut tmp = [0; DCT_BLOCK]; | |
// Argument for scoped threadpools, see file docs. | |
scoped_pools.scoped::<_, Result<(), DecodeErrors>>(|scope| { | |
for _ in 0..mcu_height | |
{ | |
// faster to memset than a later memcpy | |
// We allocate on every mcu_height since this is sent to a separate | |
// thread (that's how we're multi-threaded and thread safe). | |
let mut temporary = [vec![], vec![], vec![]]; | |
for (pos, comp) in self.components.iter().enumerate() | |
{ | |
// multiply capacity with sampling factor, it should be 1*1 for un-sampled images | |
// Allocate only needed components. | |
if min(self.output_colorspace.num_components() - 1, pos) == pos | |
{ | |
let len = component_capacity * comp.vertical_sample * comp.horizontal_sample * bias; | |
temporary[pos] = vec![0; len]; | |
} | |
} | |
// Bias only affects 4:2:0(chroma quartered) sub-sampled images. | |
// since we want to fetch two MCU rows before we send it to post process | |
for v in 0..bias | |
{ | |
for j in 0..mcu_width | |
{ | |
// iterate over components | |
for pos in 0..self.input_colorspace.num_components() | |
{ | |
let component = &mut self.components[pos]; | |
// Safety:The tables were confirmed to exist in self.check_tables(); | |
// This should be kept as is. Checking these tables here was | |
let dc_table = self | |
.dc_huffman_tables | |
.get(component.dc_huff_table) | |
.ok_or_else(|| { | |
DecodeErrors::HuffmanDecode(format!( | |
"No Huffman DC table for component {:?} ", | |
component.component_id | |
)) | |
})? | |
.as_ref() | |
.ok_or_else(|| { | |
DecodeErrors::HuffmanDecode(format!( | |
"No DC table for component {:?}", | |
component.component_id | |
)) | |
})?; | |
let ac_table = self | |
.ac_huffman_tables | |
.get(component.ac_huff_table) | |
// .as_ref() | |
.ok_or_else(|| { | |
DecodeErrors::HuffmanDecode(format!( | |
"No Huffman AC table for component {:?} ", | |
component.component_id | |
)) | |
})? | |
.as_ref() | |
.ok_or_else(|| { | |
DecodeErrors::HuffmanDecode(format!( | |
"No AC table for component {:?}", | |
component.component_id | |
)) | |
})?; | |
// let dc_table = unsafe { | |
// self.dc_huffman_tables | |
// .get_unchecked(component.dc_huff_table) | |
// .as_ref() | |
// .unwrap_or_else(|| std::hint::unreachable_unchecked()) | |
// }; | |
// let ac_table = unsafe { | |
// self.ac_huffman_tables | |
// .get_unchecked(component.ac_huff_table) | |
// .as_ref() | |
// .unwrap_or_else(|| std::hint::unreachable_unchecked()) | |
// }; | |
// If image is interleaved iterate over scan components, | |
// otherwise if it-s non-interleaved, these routines iterate in | |
// trivial scanline order(Y,Cb,Cr) | |
for v_samp in 0..component.vertical_sample | |
{ | |
for h_samp in 0..component.horizontal_sample | |
{ | |
// only decode needed components | |
if min(self.output_colorspace.num_components() - 1, pos) == pos | |
{ | |
// The spec https://www.w3.org/Graphics/JPEG/itu-t81.pdf page 26 | |
// Get position to write | |
// This is complex, don't even try to understand it. ~author | |
let is_y = | |
usize::from(component.component_id == ComponentID::Y); | |
// This only affects 4:2:0 images. | |
let y_offset = is_y | |
* v | |
* (hv_width_stride | |
+ (hv_width_stride * (component.vertical_sample - 1))); | |
let another_stride = | |
(width_stride * v_samp * usize::from(!is_hv)) | |
+ hv_width_stride * v_samp * usize::from(is_hv); | |
let yet_another_stride = usize::from(is_hv) | |
* (width_stride >> 2) | |
* v | |
* usize::from(component.component_id != ComponentID::Y); | |
// offset calculator. | |
let start = (j * 64 * component.horizontal_sample) | |
+ (h_samp * 64) | |
+ another_stride | |
+ y_offset | |
+ yet_another_stride; | |
// Get the location we will be writing to. | |
// It will always be zero since it's initialized per MCU height. | |
let tmp: &mut [i16; 64] = temporary.get_mut(pos).unwrap().get_mut(start..start + 64).unwrap().try_into().unwrap(); | |
stream.decode_mcu_block(reader, dc_table, ac_table, tmp, &mut component.dc_pred)?; | |
} | |
else | |
{ | |
// component not needed, decode and discard bits | |
stream.decode_mcu_block(reader, dc_table, ac_table, &mut tmp, &mut component.dc_pred)?; | |
} | |
} | |
} | |
self.todo = self.todo.wrapping_sub(1); | |
// after every interleaved MCU that's a mcu, count down restart markers. | |
if self.todo == 0 | |
{ | |
self.handle_rst(&mut stream)?; | |
} | |
} | |
} | |
} | |
// Clone things, to make multithreading safe | |
let component = global_component.clone(); | |
let next_chunk = chunks.next().unwrap(); | |
scope.execute(move || { | |
post_process(&mut temporary, &component, | |
idct_func, color_convert_16, | |
input, output, next_chunk, | |
width); | |
}); | |
} | |
//everything is okay | |
Ok(()) | |
})?; | |
info!("Finished decoding image"); | |
// remove excess allocation for images. | |
global_channel.truncate( | |
usize::from(self.width()) | |
* usize::from(self.height()) | |
* self.output_colorspace.num_components(), | |
); | |
return Ok(global_channel); | |
} | |
// handle RST markers. | |
// No-op if not using restarts | |
// this routine is shared with mcu_prog | |
#[cold] | |
pub(crate) fn handle_rst(&mut self, stream: &mut BitStream) -> Result<(), DecodeErrors> | |
{ | |
self.todo = self.restart_interval; | |
if let Some(marker) = stream.marker | |
{ | |
// Found a marker | |
// Read stream and see what marker is stored there | |
match marker | |
{ | |
Marker::RST(_) => | |
{ | |
// reset stream | |
stream.reset(); | |
// Initialize dc predictions to zero for all components | |
self.components.iter_mut().for_each(|x| x.dc_pred = 0); | |
// Start iterating again. from position. | |
} | |
Marker::EOI => | |
{ | |
// silent pass | |
} | |
_ => | |
{ | |
return Err(DecodeErrors::MCUError(format!( | |
"Marker {:?} found in bitstream, possibly corrupt jpeg", | |
marker | |
))); | |
} | |
} | |
} | |
Ok(()) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment