Skip to content

Instantly share code, notes, and snippets.

@Alxandr
Created March 25, 2024 22:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Alxandr/0e574150b7eb51c54c6d1b77bce3d179 to your computer and use it in GitHub Desktop.
Save Alxandr/0e574150b7eb51c54c6d1b77bce3d179 to your computer and use it in GitHub Desktop.
use rkyv::{
bytecheck::{CheckBytes, Verify},
out_field,
rancor::{Error, Fallible, Strategy},
ser::{Positional, Writer, WriterExt},
validation::{validators::DefaultValidator, ArchiveContext, ArchiveContextExt},
Archive, Portable, RawRelPtr, Serialize,
};
use std::mem;
/// Errors reported by this module when downcasting an `ArchivedAny`.
#[derive(thiserror::Error, Debug)]
pub enum ArchivedAnyError {
/// Returned by `ArchivedAny::downcast` when the `ArchivedAny` itself does not lie inside the
/// buffer it is asked to validate against.
#[error("out of bounds")]
OutOfBounds,
}
/// A type-erased archived value: a single relative pointer with no record of the target type.
///
/// The struct is `#[repr(transparent)]` over `RawRelPtr`, which is what allows a pointer to it to
/// be reinterpreted as a pointer to `ArchivedDowncasted<T>` when downcasting.
#[derive(CheckBytes)]
#[check_bytes(crate = "::rkyv::bytecheck")]
#[repr(transparent)]
pub struct ArchivedAny {
// Relative pointer to the erased value; written by `ArchivedAny::resolve`.
ptr: RawRelPtr,
}
impl ArchivedAny {
pub unsafe fn resolve(pos: usize, resolver: AnyResolver, out: *mut Self) {
let (fp, fo) = out_field!(out.ptr);
RawRelPtr::emplace(pos + fp, resolver.pos, fo);
}
}
/// Resolver for an `ArchivedAny`, produced by `SerializeAnyExt::serialize_as_any` and consumed by
/// `ArchivedAny::resolve`.
pub struct AnyResolver {
// Position of the serialized target value, as returned by the serializer.
pos: usize,
}
impl ArchivedAny {
    /// Returns `true` when every byte of `self` lies inside `buffer`.
    ///
    /// The check compares addresses only; it does not validate the pointer stored in `self`.
    pub fn is_contained_in_buffer(&self, buffer: &[u8]) -> bool {
        let buffer_start = buffer.as_ptr() as usize;
        let buffer_end = buffer_start + buffer.len();
        let self_start = self as *const Self as usize;
        let self_end = self_start + mem::size_of_val(self);
        // Both `*_end` values are one-past-the-end addresses, so a value that finishes exactly
        // where the buffer finishes is still fully contained. Testing `self_end` with
        // `Range::contains` (half-open) would wrongly reject that case.
        buffer_start <= self_start && self_end <= buffer_end
    }

    /// Validates the pointed-to value as a `T` using a fresh `DefaultValidator` over `buffer`
    /// and, on success, returns `self` viewed as an `ArchivedDowncasted<T>`.
    ///
    /// # Errors
    ///
    /// Returns an error built from `ArchivedAnyError::OutOfBounds` when `self` is not located
    /// inside `buffer`, or whatever error validation of the target produces.
    pub fn downcast<T: Portable, E>(&self, buffer: &[u8]) -> Result<&ArchivedDowncasted<T>, E>
    where
        T: CheckBytes<Strategy<DefaultValidator, E>>,
        E: Error,
    {
        // First check that the buffer contains all of self.
        if !self.is_contained_in_buffer(buffer) {
            return Err(E::new(ArchivedAnyError::OutOfBounds));
        }
        let mut validator = DefaultValidator::new(buffer);
        self.downcast_with_context(Strategy::<DefaultValidator, E>::wrap(&mut validator))
    }

    /// Same as `downcast`, but validates with a caller-supplied context and performs no
    /// buffer-containment check of its own.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while checking `ArchivedDowncasted<T>`.
    pub fn downcast_with_context<T: Portable, C>(
        &self,
        context: &mut C,
    ) -> Result<&ArchivedDowncasted<T>, C::Error>
    where
        T: CheckBytes<C>,
        C: ArchiveContext + Fallible + ?Sized,
        C::Error: Error,
    {
        let self_ptr = self as *const Self;
        let casted_ptr = self_ptr as *const ArchivedDowncasted<T>;
        // SAFETY: ArchivedDowncasted<T> is repr(transparent) over ArchivedAny, so any pointer to
        // ArchivedAny is aligned and sized correctly to be casted to any ArchivedDowncasted<T>.
        unsafe { <ArchivedDowncasted<T> as CheckBytes<C>>::check_bytes(casted_ptr, context)? };
        // SAFETY: `casted_ptr` was derived from `&self` and has just passed `check_bytes`, so it
        // refers to a checked `ArchivedDowncasted<T>` for as long as `self` is borrowed.
        Ok(unsafe { &*casted_ptr })
    }
}
/// An `ArchivedAny` whose target has been checked to be a `T`.
///
/// `#[repr(transparent)]` over `ArchivedAny` (the `PhantomData` is zero-sized), so it has the
/// same layout as the untyped value. The `verify` option makes the derived `CheckBytes` impl
/// additionally run the `Verify` impl for this type.
#[derive(CheckBytes)]
#[check_bytes(verify, crate = "::rkyv::bytecheck")]
#[repr(transparent)]
pub struct ArchivedDowncasted<T> {
// The underlying type-erased pointer.
any: ArchivedAny,
// Records the target type without storing a value of it.
phantom: std::marker::PhantomData<T>,
}
impl<T> AsRef<T> for ArchivedDowncasted<T> {
    /// Follows the stored relative pointer and returns the target as a `&T`.
    fn as_ref(&self) -> &T {
        // SAFETY: NOTE(review) — this relies on every `ArchivedDowncasted<T>` having been
        // obtained through validation (e.g. `ArchivedAny::downcast`), so that the pointer
        // targets a checked `T`. Confirm no other construction path exists.
        unsafe {
            let target = self.any.ptr.as_ptr() as *const T;
            &*target
        }
    }
}
// SAFETY: `ArchivedAny` is `#[repr(transparent)]` over a single `RawRelPtr`.
// NOTE(review): this relies on `RawRelPtr` itself being `Portable` — confirm against rkyv.
unsafe impl Portable for ArchivedAny {}
// SAFETY: `ArchivedDowncasted<T>` is `#[repr(transparent)]` over `ArchivedAny` plus a zero-sized
// `PhantomData<T>`, so it has the same representation as `ArchivedAny`.
unsafe impl<T: Portable> Portable for ArchivedDowncasted<T> {}
unsafe impl<C, T> Verify<C> for ArchivedDowncasted<T>
where
    T: CheckBytes<C>,
    C: ArchiveContext + Fallible + ?Sized,
    C::Error: Error,
{
    /// Checks that the relative pointer targets a valid `T` within the context's bounds.
    ///
    /// The steps are: bounds-check the target, enter its subtree, check the target's bytes,
    /// then leave the subtree. Any failure is propagated immediately.
    fn verify(&self, context: &mut C) -> Result<(), <C as Fallible>::Error> {
        let rel_ptr = &self.any.ptr;
        // Resolve base + offset to a `*const T`, failing if it falls outside the allowed range.
        let target = unsafe {
            context.bounds_check_subtree_base_offset::<T>(rel_ptr.base(), rel_ptr.offset(), ())?
        };
        let subtree = unsafe { context.push_prefix_subtree(target)? };
        unsafe { T::check_bytes(target, context)? };
        unsafe { context.pop_subtree_range(subtree)? };
        Ok(())
    }
}
/// Extension trait for serializing a value so that it can be referenced through an
/// `ArchivedAny`.
pub trait SerializeAnyExt<S: Fallible + ?Sized>: Serialize<S> + Sized
where
S: Positional + Writer,
{
/// Serializes `self` into `serializer` and returns an `AnyResolver` recording where the
/// value was written.
///
/// # Errors
///
/// Returns the serializer's error if any step of serialization fails.
fn serialize_as_any(&self, serializer: &mut S) -> Result<AnyResolver, <S as Fallible>::Error>;
}
impl<T: Archive + Serialize<S>, S: Fallible + ?Sized> SerializeAnyExt<S> for T
where
    S: Positional + Writer,
{
    /// Serializes `self`'s dependencies, aligns the writer for the archived form of `T`, writes
    /// the archived value, and returns its position wrapped in an `AnyResolver`.
    fn serialize_as_any(&self, serializer: &mut S) -> Result<AnyResolver, <S as Fallible>::Error> {
        let resolver = self.serialize(serializer)?;
        // The bytes written next are a `T::Archived`, not a `T`, so the writer must be aligned
        // for the archived type. Aligning for `T` leaves the value misaligned whenever the
        // archived type has the stricter alignment.
        serializer.align_for::<<T as Archive>::Archived>()?;
        // SAFETY: the writer was just aligned for `T::Archived`, and `resolver` was produced by
        // serializing this same `self`.
        let pos = unsafe { serializer.resolve_aligned(self, resolver)? };
        Ok(AnyResolver { pos })
    }
}
// Round-trip test: serialize a struct whose middle field is stored as a type-erased
// `ArchivedAny`, access it untyped, then downcast and read every field back.
#[cfg(test)]
mod tests {
use super::*;
use rkyv::{rancor::BoxedError, ser::AllocSerializer, Deserialize, Serialize};
// Target of the `Box` inside `OuterStruct`; gives the erased value a nested pointer of its own.
#[derive(Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
struct InnerStruct {
inline: usize,
referenced: String,
}
// The value that gets archived behind the `ArchivedAny`.
#[derive(Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
struct OuterStruct {
before_value: Mixed,
inner: Box<InnerStruct>,
after_value: Mixed,
}
// A struct with one inline field and one out-of-line field, used as padding around the
// interesting field.
#[derive(Archive, Serialize, Deserialize)]
#[archive(check_bytes)]
struct Mixed {
inline: usize,
referenced: String,
}
// Unarchived root value. `Archive` and `Serialize` are implemented by hand below so that `any`
// is archived as an `ArchivedAny` rather than as an `ArchivedOuterStruct`.
struct Typed {
before_value: Mixed,
any: OuterStruct,
after_value: Mixed,
}
// Resolver for `Typed`. `MixedResolver` is presumably generated by the `Archive` derive on
// `Mixed` — it is not declared in this file.
struct TypedResolver {
before_value: MixedResolver,
any: AnyResolver,
after_value: MixedResolver,
}
// Archived form of `Typed` with the middle field still type-erased.
#[derive(CheckBytes)]
#[check_bytes(crate = "::rkyv::bytecheck")]
#[repr(C)]
struct ArchivedUntyped {
before_value: ArchivedMixed,
any: ArchivedAny,
after_value: ArchivedMixed,
}
// Same field order and `#[repr(C)]` as `ArchivedUntyped`, but with the middle field typed.
// `ArchivedDowncasted<T>` is `#[repr(transparent)]` over `ArchivedAny`, which is what the pointer
// cast in `ArchivedUntyped::downcast` relies on.
#[derive(CheckBytes)]
#[check_bytes(crate = "::rkyv::bytecheck")]
#[repr(C)]
struct Downcasted {
before_value: ArchivedMixed,
any: ArchivedDowncasted<ArchivedOuterStruct>,
after_value: ArchivedMixed,
}
unsafe impl Portable for ArchivedUntyped {}
unsafe impl Portable for Downcasted {}
impl Archive for Typed {
type Archived = ArchivedUntyped;
type Resolver = TypedResolver;
// Resolves each field in declaration order; the `any` field goes through
// `ArchivedAny::resolve` instead of `Archive::resolve`.
unsafe fn resolve(&self, pos: usize, resolver: Self::Resolver, out: *mut Self::Archived) {
// before_value
let (fp, fo) = rkyv::out_field!(out.before_value);
Archive::resolve(&self.before_value, pos + fp, resolver.before_value, fo);
// any
let (fp, fo) = rkyv::out_field!(out.any);
ArchivedAny::resolve(pos + fp, resolver.any, fo);
// after_value
let (fp, fo) = rkyv::out_field!(out.after_value);
Archive::resolve(&self.after_value, pos + fp, resolver.after_value, fo);
}
}
impl<S: Fallible + ?Sized> Serialize<S> for Typed
where
S: Positional + Writer,
String: Serialize<S>,
OuterStruct: Serialize<S>,
{
// Serializes the fields in declaration order; `any` is written out-of-line via
// `serialize_as_any`, which yields the `AnyResolver` holding its position.
fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, <S as Fallible>::Error> {
let before_value = self.before_value.serialize(serializer)?;
let any = self.any.serialize_as_any(serializer)?;
let after_value = self.after_value.serialize(serializer)?;
Ok(TypedResolver {
before_value,
any,
after_value,
})
}
}
impl ArchivedUntyped {
// Validates only the `any` field as an `ArchivedOuterStruct`, then reinterprets the whole
// struct as `Downcasted`. The other two fields have the same type in both structs.
fn downcast(&self, buffer: &[u8]) -> Result<&Downcasted, BoxedError> {
match self.any.downcast::<ArchivedOuterStruct, BoxedError>(buffer) {
Err(e) => Err(e),
Ok(_) => Ok(unsafe { &*(self as *const Self as *const Downcasted) }),
}
}
}
#[test]
fn downcasting() {
// Every `inline` value is distinct (0..=4) so a field read from the wrong place is detected.
let value = Typed {
before_value: Mixed {
inline: 0,
referenced: "before".to_string(),
},
any: OuterStruct {
before_value: Mixed {
inline: 1,
referenced: "before".to_string(),
},
inner: Box::new(InnerStruct {
inline: 2,
referenced: "inner".to_string(),
}),
after_value: Mixed {
inline: 3,
referenced: "after".to_string(),
},
},
after_value: Mixed {
inline: 4,
referenced: "after".to_string(),
},
};
let bytes = rkyv::to_bytes::<Typed, 4096, BoxedError>(&value).expect("failed to archive");
// First access the root with the `any` field still untyped...
let archived =
rkyv::access::<ArchivedUntyped, BoxedError>(&bytes).expect("valid bytes for untyped");
// ...then validate the erased value against the same buffer and get the typed view.
let downcasted = archived
.downcast(&bytes)
.expect("valid bytes for downcasted");
assert_eq!(downcasted.before_value.inline, 0);
assert_eq!(downcasted.before_value.referenced, "before");
assert_eq!(downcasted.after_value.inline, 4);
assert_eq!(downcasted.after_value.referenced, "after");
// Follow the relative pointer to the archived `OuterStruct`.
let inner = downcasted.any.as_ref();
assert_eq!(inner.before_value.inline, 1);
assert_eq!(inner.before_value.referenced, "before");
assert_eq!(inner.inner.inline, 2);
assert_eq!(inner.inner.referenced, "inner");
assert_eq!(inner.after_value.inline, 3);
assert_eq!(inner.after_value.referenced, "after");
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment