Created
March 25, 2024 22:23
-
-
Save Alxandr/0e574150b7eb51c54c6d1b77bce3d179 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use rkyv::{ | |
bytecheck::{CheckBytes, Verify}, | |
out_field, | |
rancor::{Error, Fallible, Strategy}, | |
ser::{Positional, Writer, WriterExt}, | |
validation::{validators::DefaultValidator, ArchiveContext, ArchiveContextExt}, | |
Archive, Portable, RawRelPtr, Serialize, | |
}; | |
use std::mem; | |
#[derive(thiserror::Error, Debug)] | |
pub enum ArchivedAnyError { | |
#[error("out of bounds")] | |
OutOfBounds, | |
} | |
#[derive(CheckBytes)] | |
#[check_bytes(crate = "::rkyv::bytecheck")] | |
#[repr(transparent)] | |
pub struct ArchivedAny { | |
ptr: RawRelPtr, | |
} | |
impl ArchivedAny { | |
pub unsafe fn resolve(pos: usize, resolver: AnyResolver, out: *mut Self) { | |
let (fp, fo) = out_field!(out.ptr); | |
RawRelPtr::emplace(pos + fp, resolver.pos, fo); | |
} | |
} | |
/// Resolver for an [`ArchivedAny`], produced by `SerializeAnyExt::serialize_as_any`.
pub struct AnyResolver {
    /// Position in the output of the archived value that the resolved
    /// `ArchivedAny` will point to.
    pos: usize,
}
impl ArchivedAny { | |
pub fn is_contained_in_buffer(&self, buffer: &[u8]) -> bool { | |
let buffer_ptr = buffer.as_ptr() as usize; | |
let buffer_range = buffer_ptr..buffer_ptr + buffer.len(); | |
let self_ptr = self as *const Self as usize; | |
let self_range = self_ptr..self_ptr + mem::size_of_val(self); | |
buffer_range.contains(&self_range.start) && buffer_range.contains(&self_range.end) | |
} | |
pub fn downcast<T: Portable, E>(&self, buffer: &[u8]) -> Result<&ArchivedDowncasted<T>, E> | |
where | |
T: CheckBytes<Strategy<DefaultValidator, E>>, | |
E: Error, | |
{ | |
// first check that the buffer contains all of self | |
if !self.is_contained_in_buffer(buffer) { | |
return Err(E::new(ArchivedAnyError::OutOfBounds)); | |
} | |
let mut validator = DefaultValidator::new(buffer); | |
self.downcast_with_context(Strategy::<DefaultValidator, E>::wrap(&mut validator)) | |
} | |
pub fn downcast_with_context<T: Portable, C>( | |
&self, | |
context: &mut C, | |
) -> Result<&ArchivedDowncasted<T>, C::Error> | |
where | |
T: CheckBytes<C>, | |
C: ArchiveContext + Fallible + ?Sized, | |
C::Error: Error, | |
{ | |
let self_ptr = self as *const Self; | |
let casted_ptr = self_ptr as *const ArchivedDowncasted<T>; | |
// SAFETY: ArchivedDowncasted<T> is repr(transparent) over ArchivedAny, so any pointer to ArchivedAny | |
// is aligned and sized correctly to be casted to any ArchivedDowncasted<T>. | |
let result = | |
unsafe { <ArchivedDowncasted<T> as CheckBytes<C>>::check_bytes(casted_ptr, context) }; | |
match result { | |
Err(e) => Err(e), | |
Ok(()) => Ok(unsafe { &*casted_ptr }), | |
} | |
} | |
} | |
#[derive(CheckBytes)] | |
#[check_bytes(verify, crate = "::rkyv::bytecheck")] | |
#[repr(transparent)] | |
pub struct ArchivedDowncasted<T> { | |
any: ArchivedAny, | |
phantom: std::marker::PhantomData<T>, | |
} | |
impl<T> AsRef<T> for ArchivedDowncasted<T> { | |
fn as_ref(&self) -> &T { | |
unsafe { &*(self.any.ptr.as_ptr() as *const T) } | |
} | |
} | |
unsafe impl Portable for ArchivedAny {} | |
unsafe impl<T: Portable> Portable for ArchivedDowncasted<T> {} | |
unsafe impl<C, T> Verify<C> for ArchivedDowncasted<T> | |
where | |
T: CheckBytes<C>, | |
C: ArchiveContext + Fallible + ?Sized, | |
C::Error: Error, | |
{ | |
fn verify(&self, context: &mut C) -> Result<(), <C as Fallible>::Error> { | |
let base = self.any.ptr.base(); | |
let offset = self.any.ptr.offset(); | |
let ptr = unsafe { context.bounds_check_subtree_base_offset::<T>(base, offset, ())? }; | |
let range = unsafe { context.push_prefix_subtree(ptr)? }; | |
unsafe { | |
T::check_bytes(ptr, context)?; | |
} | |
unsafe { | |
context.pop_subtree_range(range)?; | |
} | |
Ok(()) | |
} | |
} | |
pub trait SerializeAnyExt<S: Fallible + ?Sized>: Serialize<S> + Sized | |
where | |
S: Positional + Writer, | |
{ | |
fn serialize_as_any(&self, serializer: &mut S) -> Result<AnyResolver, <S as Fallible>::Error>; | |
} | |
impl<T: Archive + Serialize<S>, S: Fallible + ?Sized> SerializeAnyExt<S> for T | |
where | |
S: Positional + Writer, | |
{ | |
fn serialize_as_any(&self, serializer: &mut S) -> Result<AnyResolver, <S as Fallible>::Error> { | |
let resolver = self.serialize(serializer)?; | |
serializer.align_for::<T>()?; | |
let pos = unsafe { serializer.resolve_aligned(self, resolver)? }; | |
Ok(AnyResolver { pos }) | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;
    use rkyv::{rancor::BoxedError, Deserialize, Serialize};

    #[derive(Archive, Serialize, Deserialize)]
    #[archive(check_bytes)]
    struct InnerStruct {
        inline: usize,
        referenced: String,
    }

    #[derive(Archive, Serialize, Deserialize)]
    #[archive(check_bytes)]
    struct OuterStruct {
        before_value: Mixed,
        inner: Box<InnerStruct>,
        after_value: Mixed,
    }

    #[derive(Archive, Serialize, Deserialize)]
    #[archive(check_bytes)]
    struct Mixed {
        inline: usize,
        referenced: String,
    }

    /// Native value whose middle field is archived as a type-erased `ArchivedAny`.
    struct Typed {
        before_value: Mixed,
        any: OuterStruct,
        after_value: Mixed,
    }

    /// Resolver for `Typed`.
    struct TypedResolver {
        before_value: MixedResolver,
        any: AnyResolver,
        after_value: MixedResolver,
    }

    /// Archived form of `Typed`, with the middle field left untyped.
    #[derive(CheckBytes)]
    #[check_bytes(crate = "::rkyv::bytecheck")]
    #[repr(C)]
    struct ArchivedUntyped {
        before_value: ArchivedMixed,
        any: ArchivedAny,
        after_value: ArchivedMixed,
    }

    /// Same layout as `ArchivedUntyped`, but the middle field is viewed as an
    /// `ArchivedOuterStruct`.
    #[derive(CheckBytes)]
    #[check_bytes(crate = "::rkyv::bytecheck")]
    #[repr(C)]
    struct Downcasted {
        before_value: ArchivedMixed,
        any: ArchivedDowncasted<ArchivedOuterStruct>,
        after_value: ArchivedMixed,
    }

    // SAFETY: both structs are `repr(C)` compositions of `Portable` fields.
    unsafe impl Portable for ArchivedUntyped {}
    unsafe impl Portable for Downcasted {}

    impl Archive for Typed {
        type Archived = ArchivedUntyped;
        type Resolver = TypedResolver;

        unsafe fn resolve(&self, pos: usize, resolver: Self::Resolver, out: *mut Self::Archived) {
            // before_value
            let (fp, fo) = rkyv::out_field!(out.before_value);
            Archive::resolve(&self.before_value, pos + fp, resolver.before_value, fo);
            // any: written as an untyped relative pointer to the archived `OuterStruct`
            let (fp, fo) = rkyv::out_field!(out.any);
            ArchivedAny::resolve(pos + fp, resolver.any, fo);
            // after_value
            let (fp, fo) = rkyv::out_field!(out.after_value);
            Archive::resolve(&self.after_value, pos + fp, resolver.after_value, fo);
        }
    }

    impl<S: Fallible + ?Sized> Serialize<S> for Typed
    where
        S: Positional + Writer,
        String: Serialize<S>,
        OuterStruct: Serialize<S>,
    {
        fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, <S as Fallible>::Error> {
            let before_value = self.before_value.serialize(serializer)?;
            let any = self.any.serialize_as_any(serializer)?;
            let after_value = self.after_value.serialize(serializer)?;
            Ok(TypedResolver {
                before_value,
                any,
                after_value,
            })
        }
    }

    impl ArchivedUntyped {
        /// Validates the untyped middle field as an `ArchivedOuterStruct` within
        /// `buffer`. On success, returns a typed view of the whole struct.
        fn downcast(&self, buffer: &[u8]) -> Result<&Downcasted, BoxedError> {
            self.any.downcast::<ArchivedOuterStruct, BoxedError>(buffer)?;
            // SAFETY: `Downcasted` and `ArchivedUntyped` are both `repr(C)` with
            // the same field layout. `ArchivedDowncasted<_>` is `repr(transparent)`
            // over `ArchivedAny`, and the `any` field's pointee was just validated.
            Ok(unsafe { &*(self as *const Self).cast::<Downcasted>() })
        }
    }

    #[test]
    fn downcasting() {
        let value = Typed {
            before_value: Mixed {
                inline: 0,
                referenced: "before".to_string(),
            },
            any: OuterStruct {
                before_value: Mixed {
                    inline: 1,
                    referenced: "before".to_string(),
                },
                inner: Box::new(InnerStruct {
                    inline: 2,
                    referenced: "inner".to_string(),
                }),
                after_value: Mixed {
                    inline: 3,
                    referenced: "after".to_string(),
                },
            },
            after_value: Mixed {
                inline: 4,
                referenced: "after".to_string(),
            },
        };

        let bytes = rkyv::to_bytes::<Typed, 4096, BoxedError>(&value).expect("failed to archive");
        let archived =
            rkyv::access::<ArchivedUntyped, BoxedError>(&bytes).expect("valid bytes for untyped");
        let downcasted = archived
            .downcast(&bytes)
            .expect("valid bytes for downcasted");

        assert_eq!(downcasted.before_value.inline, 0);
        assert_eq!(downcasted.before_value.referenced, "before");
        assert_eq!(downcasted.after_value.inline, 4);
        assert_eq!(downcasted.after_value.referenced, "after");

        let inner = downcasted.any.as_ref();
        assert_eq!(inner.before_value.inline, 1);
        assert_eq!(inner.before_value.referenced, "before");
        assert_eq!(inner.inner.inline, 2);
        assert_eq!(inner.inner.referenced, "inner");
        assert_eq!(inner.after_value.inline, 3);
        assert_eq!(inner.after_value.referenced, "after");
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment