Skip to content

Instantly share code, notes, and snippets.

@Alxandr
Created June 3, 2024 20:39
Show Gist options
  • Save Alxandr/5b24500c7ce05f34af6ec27c7f98973a to your computer and use it in GitHub Desktop.
Save Alxandr/5b24500c7ce05f34af6ec27c7f98973a to your computer and use it in GitHub Desktop.
rkyv-any
use rkyv::{
bytecheck::{CheckBytes, Verify},
out_field,
rancor::{Error, Fallible, Strategy},
ser::{Positional, Writer, WriterExt},
validation::{validators::DefaultValidator, ArchiveContext, ArchiveContextExt},
with::{ArchiveWith, SerializeWith},
Archive, Portable, RawRelPtr, Serialize,
};
#[derive(thiserror::Error, Debug)]
pub enum ArchivedAnyError {
#[error("out of bounds")]
OutOfBounds,
}
/// A type-erased archived value.
///
/// It stores only a raw relative pointer; nothing about the pointee's type is
/// recorded here. Use [`ArchivedAny::downcast`] to validate the target as a
/// concrete archived type.
#[repr(transparent)]
#[derive(CheckBytes)]
#[check_bytes(crate = "::rkyv::bytecheck")]
pub struct ArchivedAny {
    /// Relative pointer to the erased target.
    ptr: RawRelPtr,
}
impl ArchivedAny {
#[inline]
pub unsafe fn resolve(pos: usize, resolver: AnyResolver, out: *mut Self) {
let (fp, fo) = out_field!(out.ptr);
RawRelPtr::emplace(pos + fp, resolver.pos, fo);
}
}
/// Resolver for [`ArchivedAny`]: records where the erased value was written.
pub struct AnyResolver {
    /// Position of the serialized target, used as the relative pointer's destination.
    pos: usize,
}
impl ArchivedAny {
    /// Validates the erased target as a `T`, using a fresh [`DefaultValidator`]
    /// built over `buffer`.
    ///
    /// `buffer` must be the byte slice that contains `self`.
    ///
    /// # Errors
    ///
    /// Returns an error if `self` does not lie inside `buffer`, or if validation
    /// of the target fails.
    pub fn downcast<'a, T: Portable, E>(
        &self,
        buffer: &'a [u8],
    ) -> Result<&'a ArchivedDowncasted<T>, E>
    where
        T: CheckBytes<Strategy<DefaultValidator, E>>,
        E: Error,
    {
        let mut validator = DefaultValidator::new(buffer);
        self.downcast_with_context(buffer, &mut validator)
    }

    /// Validates the erased target as a `T` using a caller-supplied validation
    /// `context`.
    ///
    /// The position of `self` inside `buffer` is computed from the two addresses
    /// and handed to rkyv's `access_pos_with_context`, which re-reads the bytes at
    /// that position as an [`ArchivedDowncasted<T>`].
    ///
    /// # Errors
    ///
    /// Returns [`ArchivedAnyError::OutOfBounds`] (wrapped in `E`) when `self` is
    /// not fully contained in `buffer`; otherwise returns whatever error the
    /// access/validation step produces.
    pub fn downcast_with_context<'a, T: Portable, C, E>(
        &self,
        buffer: &'a [u8],
        context: &mut C,
    ) -> Result<&'a ArchivedDowncasted<T>, E>
    where
        T: CheckBytes<Strategy<C, E>>,
        C: ArchiveContext<E>,
        E: Error,
    {
        let pos = (self as *const Self as usize)
            // Rejects a `self` that starts before the buffer.
            .checked_sub(buffer.as_ptr() as usize)
            // Rejects a `self` that starts or ends past the buffer's end, so a
            // position outside `buffer` is never passed on.
            .filter(|&start| {
                start
                    .checked_add(std::mem::size_of::<Self>())
                    .map_or(false, |end| end <= buffer.len())
            })
            .ok_or_else(|| Error::new(ArchivedAnyError::OutOfBounds))?;
        rkyv::validation::util::access_pos_with_context(buffer, pos, context)
    }
}
/// An [`ArchivedAny`] whose target is statically typed as `T`.
///
/// The `check_bytes(verify, …)` attribute makes the derived `CheckBytes` impl
/// also run this type's `Verify` impl.
#[repr(transparent)]
#[derive(CheckBytes)]
#[check_bytes(verify, crate = "::rkyv::bytecheck")]
pub struct ArchivedDowncasted<T> {
    /// The underlying untyped pointer.
    any: ArchivedAny,
    /// Carries the target type without storing a `T`.
    phantom: std::marker::PhantomData<T>,
}
impl<T> ArchivedDowncasted<T>
where
    T: Portable,
{
    /// Resolves an `ArchivedDowncasted<T>` by delegating to [`ArchivedAny::resolve`].
    ///
    /// # Safety
    ///
    /// Same requirements as [`ArchivedAny::resolve`]; `out` is reinterpreted as a
    /// pointer to the wrapped `ArchivedAny`.
    #[inline]
    pub unsafe fn resolve(pos: usize, resolver: AnyResolver, out: *mut Self) {
        ArchivedAny::resolve(pos, resolver, out.cast::<ArchivedAny>());
    }
}
impl<T> AsRef<T> for ArchivedDowncasted<T> {
    /// Follows the relative pointer and returns the target as a `&T`.
    fn as_ref(&self) -> &T {
        // SAFETY (NOTE(review) — confirm): this relies on the value having passed
        // this type's `Verify` impl, which checks the pointee as a `T`. A value
        // obtained some other way (e.g. via a pointer cast) skips that check.
        unsafe { &*self.any.ptr.as_ptr().cast::<T>() }
    }
}
// `ArchivedAny` is `repr(transparent)` over a `RawRelPtr`.
// NOTE(review): confirm this meets rkyv's `Portable` requirements.
unsafe impl Portable for ArchivedAny {}

// `ArchivedDowncasted<T>` is `repr(transparent)` over `ArchivedAny` plus a `PhantomData`.
unsafe impl<T> Portable for ArchivedDowncasted<T> where T: Portable {}
unsafe impl<C, T> Verify<C> for ArchivedDowncasted<T>
where
    T: CheckBytes<C>,
    C: ArchiveContext + Fallible + ?Sized,
    C::Error: Error,
{
    /// Checks that the relative pointer targets a valid `T`.
    ///
    /// Steps, in order: bounds-check the pointer target as a `T`, push a subtree
    /// range for it, run `T::check_bytes` on the target, then pop the range. Any
    /// failing step returns its error immediately.
    fn verify(&self, context: &mut C) -> Result<(), C::Error> {
        let rel = &self.any.ptr;

        // SAFETY: NOTE(review) — the preconditions of these context methods are not
        // visible here; confirm against rkyv's `ArchiveContextExt` docs.
        let target = unsafe {
            context.bounds_check_subtree_base_offset::<T>(rel.base(), rel.offset(), ())?
        };
        let subtree = unsafe { context.push_prefix_subtree(target)? };

        unsafe {
            T::check_bytes(target, context)?;
            context.pop_subtree_range(subtree)?;
        }

        Ok(())
    }
}
pub trait SerializeAnyExt<S: Fallible + ?Sized>: Serialize<S> + Sized
where
S: Positional + Writer,
{
fn serialize_as_any(&self, serializer: &mut S) -> Result<AnyResolver, <S as Fallible>::Error>;
}
impl<T: Archive + Serialize<S>, S: Fallible + ?Sized> SerializeAnyExt<S> for T
where
    S: Positional + Writer,
{
    /// Serializes `self`'s dependencies, aligns the writer for the archived form,
    /// writes the archived value, and returns a resolver holding its position.
    ///
    /// # Errors
    ///
    /// Returns the serializer's error if serialization, alignment, or writing fails.
    fn serialize_as_any(&self, serializer: &mut S) -> Result<AnyResolver, <S as Fallible>::Error> {
        let resolver = self.serialize(serializer)?;
        // The value written below is a `T::Archived`, so the writer must be aligned
        // for the archived type, not for the in-memory `T`; their alignments can differ.
        serializer.align_for::<T::Archived>()?;
        // SAFETY: `resolver` was produced by serializing `self` just above, and the
        // writer has just been aligned for `T::Archived`.
        let pos = unsafe { serializer.resolve_aligned(self, resolver)? };
        Ok(AnyResolver { pos })
    }
}
/// Field wrapper (for use with `#[with(AsAny)]`) that archives the field as an
/// untyped [`ArchivedAny`] instead of its own archived type.
#[derive(Debug)]
pub struct AsAny;
impl<T> ArchiveWith<T> for AsAny
where
    T: Archive,
{
    type Archived = ArchivedAny;
    type Resolver = AnyResolver;

    /// Resolves the wrapped field as an [`ArchivedAny`]; the field value itself is
    /// not consulted, only the resolver.
    unsafe fn resolve_with(
        _field: &T,
        pos: usize,
        resolver: Self::Resolver,
        out: *mut Self::Archived,
    ) {
        ArchivedAny::resolve(pos, resolver, out);
    }
}
impl<T: Archive, S: Fallible + ?Sized> SerializeWith<T, S> for AsAny
where
S: Writer + Positional,
S::Error: Error,
T: Serialize<S>,
{
fn serialize_with(
field: &T,
serializer: &mut S,
) -> Result<Self::Resolver, <S as Fallible>::Error> {
field.serialize_as_any(serializer)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rkyv::{rancor::BoxedError, Deserialize, Serialize};

    /// Boxed payload nested inside [`OuterStruct`].
    #[derive(Archive, Serialize, Deserialize)]
    #[archive(check_bytes)]
    struct InnerStruct {
        inline: usize,
        referenced: String,
    }

    /// The value that gets archived behind the `AsAny` wrapper.
    #[derive(Archive, Serialize, Deserialize)]
    #[archive(check_bytes)]
    struct OuterStruct {
        before_value: Mixed,
        inner: Box<InnerStruct>,
        after_value: Mixed,
    }

    /// One inline field plus one out-of-line (string) field.
    #[derive(Archive, Serialize, Deserialize)]
    #[archive(check_bytes)]
    struct Mixed {
        inline: usize,
        referenced: String,
    }

    /// Source type; archives as [`ArchivedUntyped`], with `any` stored as an `ArchivedAny`.
    #[derive(Archive, Serialize)]
    #[archive(as = "ArchivedUntyped")]
    struct Typed {
        before_value: Mixed,
        #[with(AsAny)]
        any: OuterStruct,
        after_value: Mixed,
    }

    /// Archived view in which the middle field is still type-erased.
    #[derive(CheckBytes)]
    #[check_bytes(crate = "::rkyv::bytecheck")]
    #[repr(C)]
    struct ArchivedUntyped {
        before_value: ArchivedMixed,
        any: ArchivedAny,
        after_value: ArchivedMixed,
    }

    /// Same layout as [`ArchivedUntyped`], but with the middle field typed.
    #[derive(CheckBytes)]
    #[check_bytes(crate = "::rkyv::bytecheck")]
    #[repr(C)]
    struct Downcasted {
        before_value: ArchivedMixed,
        any: ArchivedDowncasted<ArchivedOuterStruct>,
        after_value: ArchivedMixed,
    }

    unsafe impl Portable for ArchivedUntyped {}
    unsafe impl Portable for Downcasted {}

    impl ArchivedUntyped {
        /// Validates the erased field as an `ArchivedOuterStruct` and, on success,
        /// reinterprets the whole struct as [`Downcasted`].
        fn downcast(&self, buffer: &[u8]) -> Result<&Downcasted, BoxedError> {
            self.any
                .downcast::<ArchivedOuterStruct, BoxedError>(buffer)
                // SAFETY: both structs are `repr(C)` with the same field sequence, and
                // `ArchivedDowncasted` is `repr(transparent)` over `ArchivedAny`.
                .map(|_| unsafe { &*(self as *const Self).cast::<Downcasted>() })
        }
    }

    /// Small constructor to keep the fixture readable.
    fn mixed(inline: usize, referenced: &str) -> Mixed {
        Mixed {
            inline,
            referenced: referenced.to_string(),
        }
    }

    #[test]
    fn downcasting() {
        let value = Typed {
            before_value: mixed(0, "before"),
            any: OuterStruct {
                before_value: mixed(1, "before"),
                inner: Box::new(InnerStruct {
                    inline: 2,
                    referenced: "inner".to_string(),
                }),
                after_value: mixed(3, "after"),
            },
            after_value: mixed(4, "after"),
        };

        let bytes = rkyv::to_bytes::<Typed, 4096, BoxedError>(&value).expect("failed to archive");
        let untyped =
            rkyv::access::<ArchivedUntyped, BoxedError>(&bytes).expect("valid bytes for untyped");
        let typed = untyped
            .downcast(&bytes)
            .expect("valid bytes for downcasted");

        // The fields surrounding the erased value are untouched by the downcast.
        assert_eq!(typed.before_value.inline, 0);
        assert_eq!(typed.before_value.referenced, "before");
        assert_eq!(typed.after_value.inline, 4);
        assert_eq!(typed.after_value.referenced, "after");

        // The erased value is reachable with its concrete archived type.
        let outer = typed.any.as_ref();
        assert_eq!(outer.before_value.inline, 1);
        assert_eq!(outer.before_value.referenced, "before");
        assert_eq!(outer.inner.inline, 2);
        assert_eq!(outer.inner.referenced, "inner");
        assert_eq!(outer.after_value.inline, 3);
        assert_eq!(outer.after_value.referenced, "after");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment