Skip to content

Instantly share code, notes, and snippets.

@dtolnay
Last active July 28, 2018 21:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dtolnay/dd05bc98c438e160d142fbbc0272d92c to your computer and use it in GitHub Desktop.
Save dtolnay/dd05bc98c438e160d142fbbc0272d92c to your computer and use it in GitHub Desktop.
Rc<> DAG serialization/deserialization proof of concept
//! ```toml
//! [dependencies]
//! serde = "0.9"
//! serde_json = "0.9"
//! bincode = "=1.0.0-alpha5"
//! ```
extern crate serde;
extern crate serde_json;
extern crate bincode;
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt;
use std::rc::Rc;
use serde::ser::{Serialize, Serializer, SerializeStruct, SerializeStructVariant};
use serde::de::{Deserialize, DeserializeSeed, Deserializer, Visitor, SeqVisitor,
MapVisitor, EnumVisitor, VariantVisitor, Error, Unexpected};
use bincode::SizeLimit;
//////////////////////////////////////////////////////////////////////////////
/// A binary node whose children are held through `Rc`, so one child can be
/// owned by several parents at once.
#[derive(Debug)]
struct Node {
/// Payload carried by this node.
data: char,
/// Left child, if any; possibly shared with other parents.
left: Option<Rc<Node>>,
/// Right child, if any; possibly shared with other parents.
right: Option<Rc<Node>>,
}
/// Builds the sample graph: `C` is a child of both `A` and `B`, and `E` is a
/// child of both `C` and `B`.
///
/// ```text
///     A
///    / \
///   (   B
///    \ / \
///     C   )
///    / \ /
///   D   E
/// ```
fn example() -> Node {
    // Single place where a reference-counted node is allocated.
    fn node(data: char, left: Option<Rc<Node>>, right: Option<Rc<Node>>) -> Rc<Node> {
        Rc::new(Node { data: data, left: left, right: right })
    }

    let e = node('E', None, None);
    let d = node('D', None, None);
    // `e` is cloned here because `B` below takes the second handle to it.
    let c = node('C', Some(d), Some(e.clone()));
    // `c` is cloned here because the root takes the second handle to it.
    let b = node('B', Some(c.clone()), Some(e));

    Node { data: 'A', left: Some(c), right: Some(b) }
}
fn check(a: &Node) {
let b = a.right.as_ref().unwrap();
let c = a.left.as_ref().unwrap();
let d = c.left.as_ref().unwrap();
let e = c.right.as_ref().unwrap();
assert_eq!('A', a.data);
assert_eq!('B', b.data);
assert_eq!('C', c.data);
assert_eq!('D', d.data);
assert_eq!('E', e.data);
assert_eq!(&**c as *const Node, &**b.left.as_ref().unwrap() as *const Node);
assert_eq!(&**e as *const Node, &**b.right.as_ref().unwrap() as *const Node);
}
/// Round-trips the example graph through JSON and through bincode, checking
/// after each step that sharing survived, and prints both encodings. The
/// bincode bytes are printed piece by piece with a label for each piece.
fn main() {
    let root = example();
    check(&root);

    // JSON round trip.
    let j = serde_json::to_string_pretty(&root).unwrap();
    check(&serde_json::from_str(&j).unwrap());
    println!("{}", j);

    // bincode round trip.
    let b = bincode::serialize(&root, SizeLimit::Infinite).unwrap();
    check(&bincode::deserialize(&b).unwrap());

    // (start, end, label) for each piece of the encoded output, in order.
    let layout: [(usize, usize, &str); 17] = [
        (0, 1, "data=A"),
        (1, 6, "left=marked"),
        (6, 10, "id=0"),
        (10, 11, "data=C"),
        (11, 16, "left=plain"),
        (16, 17, "data=D"),
        (17, 19, "left=none right=none"),
        (19, 24, "right=marked"),
        (24, 28, "id=1"),
        (28, 29, "data=E"),
        (29, 31, "left=none right=none"),
        (31, 36, "right=plain"),
        (36, 37, "data=B"),
        (37, 42, "left=reference"),
        (42, 46, "id=0"),
        (46, 51, "right=reference"),
        (51, 55, "id=1"),
    ];
    for &(start, end, label) in layout.iter() {
        println!("{:?} // {}", &b[start..end], label);
    }

    assert_eq!(b.len(), 55);
    println!("{:?}", b);
}
//////////////////////////////////////////////////////////////////////////////
/// Identifier written into the output for a node that is referenced more than once.
type Id = u32;
/// Serialization-side table: address of a node -> id assigned to it.
type NodeToId = HashMap<*const Node, Id>;
/// Deserialization-side table: id -> node that was registered under it.
type IdToNode = HashMap<Id, Rc<Node>>;
/// Result of `node_to_id` for one node.
enum Lookup {
/// The `Rc` has a strong count of 1; no id is assigned.
Unique,
/// The node was already in the table under this id.
Found(Id),
/// The node was not in the table and has just been added under this id.
Inserted(Id),
}
/// Classifies `node` for serialization and, when needed, assigns it an id.
///
/// * `Lookup::Unique` if the `Rc` has exactly one strong reference; the table
///   is left untouched.
/// * `Lookup::Found(id)` if the node's address is already in `map`.
/// * `Lookup::Inserted(id)` otherwise; the node is added to `map` with an id
///   equal to the number of entries the table held before the insertion.
///
/// Panics if `map` is already borrowed.
fn node_to_id(map: &RefCell<NodeToId>, node: &Rc<Node>) -> Lookup {
    if Rc::strong_count(node) == 1 {
        return Lookup::Unique;
    }

    // Nodes are keyed by the address of the value behind the `Rc`.
    let address: *const Node = &**node;
    let mut ids = map.borrow_mut();
    let known = ids.get(&address).cloned();
    match known {
        Some(id) => Lookup::Found(id),
        None => {
            let id = ids.len() as Id;
            ids.insert(address, id);
            Lookup::Inserted(id)
        }
    }
}
//////////////////////////////////////////////////////////////////////////////
/// Like `Serialize`, but with access to the table of ids handed out so far,
/// so that an implementation can tell shared nodes from unshared ones.
trait SerializeDag {
/// Serializes `self` into `serializer`; `map` is the id table of the
/// serialization in progress and may be updated by the call.
fn serialize_dag<S>(&self, serializer: S, map: &RefCell<NodeToId>) -> Result<S::Ok, S::Error>
where S: Serializer;
}
/// A child node bundled with the id table, so the pair can be passed to
/// serde as one `Serialize` value.
struct Tracked<'a> {
/// The child node to serialize.
node: &'a Rc<Node>,
/// Id table of the serialization in progress.
map: &'a RefCell<NodeToId>,
}
impl<'a> Tracked<'a> {
    /// Wraps an optional child: `None` stays `None`, `Some(child)` becomes a
    /// `Tracked` that borrows both the child and `map`.
    fn option(node: &'a Option<Rc<Node>>, map: &'a RefCell<NodeToId>) -> Option<Self> {
        let ids = map;
        node.as_ref().map(|child| Tracked {
            node: child,
            map: ids,
        })
    }
}
impl<'a> Serialize for Tracked<'a> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
{
self.node.serialize_dag(serializer, self.map)
}
}
impl Serialize for Node {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
{
let mut state = serializer.serialize_struct("Node", 3)?;
state.serialize_field("data", &self.data)?;
let map = RefCell::<NodeToId>::default();
state.serialize_field("left", &Tracked::option(&self.left, &map))?;
state.serialize_field("right", &Tracked::option(&self.right, &map))?;
state.end()
}
}
impl SerializeDag for Rc<Node> {
    /// Serializes a child node as one of three enum variants, chosen by
    /// `node_to_id`:
    ///
    /// * `Plain { data, left, right }` for a node with a single owner;
    /// * `Marked { id, data, left, right }` the first time a shared node is met;
    /// * `Reference(id)` for every later occurrence of that shared node.
    fn serialize_dag<S>(&self, serializer: S, map: &RefCell<NodeToId>) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // The lookup happens before any child is written, so a shared node
        // gets its id before its descendants do.
        let (index, variant, id) = match node_to_id(map, self) {
            Lookup::Found(id) => {
                return serializer.serialize_newtype_variant("Node", 2, "Reference", &id);
            }
            Lookup::Unique => (0, "Plain", None),
            Lookup::Inserted(id) => (1, "Marked", Some(id)),
        };

        // `Marked` carries one more field than `Plain`: the leading `id`.
        let len = if id.is_some() { 4 } else { 3 };
        let mut state = serializer.serialize_struct_variant("Node", index, variant, len)?;
        if let Some(id) = id {
            state.serialize_field("id", &id)?;
        }
        state.serialize_field("data", &self.data)?;
        state.serialize_field("left", &Tracked::option(&self.left, map))?;
        state.serialize_field("right", &Tracked::option(&self.right, map))?;
        state.end()
    }
}
//////////////////////////////////////////////////////////////////////////////
impl Deserialize for Node {
    /// Deserializes a root node from a struct named `"Node"` with the fields
    /// listed in `FIELDS`, using `RootNodeVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer,
    {
        let visitor = RootNodeVisitor;
        deserializer.deserialize_struct("Node", FIELDS, visitor)
    }
}
/// Visitor that builds the root `Node`.
///
/// Each visit creates one empty `IdToNode` table and lends it to both
/// children, so an id registered while reading one child can be resolved
/// while reading the other.
struct RootNodeVisitor;

impl Visitor for RootNodeVisitor {
    type Value = Node;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("struct Node")
    }

    /// Positional form: `data`, `left`, `right`, in that order. A sequence
    /// that ends early is reported as an invalid length.
    fn visit_seq<V>(self, mut visitor: V) -> Result<Node, V::Error>
    where
        V: SeqVisitor,
    {
        let mut map = IdToNode::default();
        let data = visitor.visit()?
            .ok_or_else(|| Error::invalid_length(0, &self))?;
        let left = visitor.visit_seed(OptionSeed(NodeSeed { map: &mut map }))?
            .ok_or_else(|| Error::invalid_length(1, &self))?;
        let right = visitor.visit_seed(OptionSeed(NodeSeed { map: &mut map }))?
            .ok_or_else(|| Error::invalid_length(2, &self))?;
        Ok(Node { data: data, left: left, right: right })
    }

    /// Keyed form: `data`, `left` and `right` in any order.
    ///
    /// `data` is required. A missing `left` or `right` is read as "no
    /// child". Any key that appears twice is rejected as a duplicate field.
    fn visit_map<V>(self, mut visitor: V) -> Result<Node, V::Error>
    where
        V: MapVisitor,
    {
        let mut map = IdToNode::default();
        let mut data = None;
        // Outer `Option`: has the key been seen? Inner `Option`: the child.
        // Keeping them apart lets a key whose value is "no child" still count
        // as seen, so a second occurrence of that key is caught.
        let mut left: Option<Option<Rc<Node>>> = None;
        let mut right: Option<Option<Rc<Node>>> = None;

        while let Some(key) = visitor.visit_key()? {
            match key {
                Field::Data => {
                    if data.is_some() {
                        return Err(Error::duplicate_field("data"));
                    }
                    data = Some(visitor.visit_value()?);
                }
                Field::Left => {
                    if left.is_some() {
                        return Err(Error::duplicate_field("left"));
                    }
                    left = Some(visitor.visit_value_seed(OptionSeed(NodeSeed { map: &mut map }))?);
                }
                Field::Right => {
                    if right.is_some() {
                        return Err(Error::duplicate_field("right"));
                    }
                    right = Some(visitor.visit_value_seed(OptionSeed(NodeSeed { map: &mut map }))?);
                }
            }
        }

        let data = data.ok_or_else(|| Error::missing_field("data"))?;
        Ok(Node {
            data: data,
            left: left.unwrap_or(None),
            right: right.unwrap_or(None),
        })
    }
}
//////////////////////////////////////////////////////////////////////////////
/// Seed for deserializing one child node. It carries the id table so that
/// `Marked` nodes can be registered and `Reference`s resolved.
struct NodeSeed<'a> {
/// Nodes registered so far, keyed by id.
map: &'a mut IdToNode,
}
impl<'a> Visitor for NodeSeed<'a> {
type Value = Rc<Node>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct Node")
}
fn visit_enum<V>(self, visitor: V) -> Result<Self::Value, V::Error>
where V: EnumVisitor
{
match visitor.visit_variant()? {
(Variant::Plain, variant) => {
variant.visit_struct(FIELDS, PlainNodeVisitor { map: self.map })
}
(Variant::Marked, variant) => {
variant.visit_struct(MARKED_FIELDS, MarkedNodeVisitor { map: self.map })
}
(Variant::Reference, variant) => {
let id = variant.visit_newtype()?;
match self.map.get(&id) {
Some(rc) => Ok(rc.clone()),
None => Err(Error::custom(format_args!("missing id {}", id))),
}
}
}
}
}
impl<'a> DeserializeSeed for NodeSeed<'a> {
type Value = Rc<Node>;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where D: Deserializer
{
deserializer.deserialize_enum("Node", VARIANTS, self)
}
}
/// Visitor for the `Plain` variant: a child node that carries no id and is
/// therefore never entered into the id table itself.
struct PlainNodeVisitor<'a> {
    /// Id table, passed on to the children of this node.
    map: &'a mut IdToNode,
}

impl<'a> Visitor for PlainNodeVisitor<'a> {
    type Value = Rc<Node>;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("struct Node")
    }

    /// Positional form: `data`, `left`, `right`, in that order. A sequence
    /// that ends early is reported as an invalid length.
    fn visit_seq<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
    where
        V: SeqVisitor,
    {
        let data = visitor.visit()?
            .ok_or_else(|| Error::invalid_length(0, &self))?;
        let left = visitor.visit_seed(OptionSeed(NodeSeed { map: self.map }))?
            .ok_or_else(|| Error::invalid_length(1, &self))?;
        let right = visitor.visit_seed(OptionSeed(NodeSeed { map: self.map }))?
            .ok_or_else(|| Error::invalid_length(2, &self))?;
        Ok(Rc::new(Node { data: data, left: left, right: right }))
    }

    /// Keyed form: `data`, `left` and `right` in any order.
    ///
    /// `data` is required. A missing `left` or `right` is read as "no
    /// child". Any key that appears twice is rejected as a duplicate field.
    fn visit_map<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
    where
        V: MapVisitor,
    {
        let mut data = None;
        // Outer `Option`: has the key been seen? Inner `Option`: the child.
        // Keeping them apart lets a key whose value is "no child" still count
        // as seen, so a second occurrence of that key is caught.
        let mut left: Option<Option<Rc<Node>>> = None;
        let mut right: Option<Option<Rc<Node>>> = None;

        while let Some(key) = visitor.visit_key()? {
            match key {
                Field::Data => {
                    if data.is_some() {
                        return Err(Error::duplicate_field("data"));
                    }
                    data = Some(visitor.visit_value()?);
                }
                Field::Left => {
                    if left.is_some() {
                        return Err(Error::duplicate_field("left"));
                    }
                    left = Some(visitor.visit_value_seed(OptionSeed(NodeSeed { map: self.map }))?);
                }
                Field::Right => {
                    if right.is_some() {
                        return Err(Error::duplicate_field("right"));
                    }
                    right = Some(visitor.visit_value_seed(OptionSeed(NodeSeed { map: self.map }))?);
                }
            }
        }

        let data = data.ok_or_else(|| Error::missing_field("data"))?;
        Ok(Rc::new(Node {
            data: data,
            left: left.unwrap_or(None),
            right: right.unwrap_or(None),
        }))
    }
}
/// Visitor for the `Marked` variant: a child node that carries an id. Once
/// built, the node is stored in the id table under that id so that later
/// `Reference`s can resolve to it.
struct MarkedNodeVisitor<'a> {
    /// Id table; passed on to the children and updated with this node.
    map: &'a mut IdToNode,
}

impl<'a> Visitor for MarkedNodeVisitor<'a> {
    type Value = Rc<Node>;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("struct Node")
    }

    /// Positional form: `id`, `data`, `left`, `right`, in that order. A
    /// sequence that ends early is reported as an invalid length.
    fn visit_seq<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
    where
        V: SeqVisitor,
    {
        let id = visitor.visit()?
            .ok_or_else(|| Error::invalid_length(0, &self))?;
        let data = visitor.visit()?
            .ok_or_else(|| Error::invalid_length(1, &self))?;
        let left = visitor.visit_seed(OptionSeed(NodeSeed { map: self.map }))?
            .ok_or_else(|| Error::invalid_length(2, &self))?;
        let right = visitor.visit_seed(OptionSeed(NodeSeed { map: self.map }))?
            .ok_or_else(|| Error::invalid_length(3, &self))?;
        // The node is registered only after both children have been read.
        let node = Rc::new(Node { data: data, left: left, right: right });
        self.map.insert(id, node.clone());
        Ok(node)
    }

    /// Keyed form: `id`, `data`, `left` and `right` in any order.
    ///
    /// `id` and `data` are required. A missing `left` or `right` is read as
    /// "no child". Any key that appears twice is rejected as a duplicate
    /// field.
    fn visit_map<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
    where
        V: MapVisitor,
    {
        let mut id = None;
        let mut data = None;
        // Outer `Option`: has the key been seen? Inner `Option`: the child.
        // Keeping them apart lets a key whose value is "no child" still count
        // as seen, so a second occurrence of that key is caught.
        let mut left: Option<Option<Rc<Node>>> = None;
        let mut right: Option<Option<Rc<Node>>> = None;

        while let Some(key) = visitor.visit_key()? {
            match key {
                MarkedField::Id => {
                    if id.is_some() {
                        return Err(Error::duplicate_field("id"));
                    }
                    id = Some(visitor.visit_value()?);
                }
                MarkedField::Data => {
                    if data.is_some() {
                        return Err(Error::duplicate_field("data"));
                    }
                    data = Some(visitor.visit_value()?);
                }
                MarkedField::Left => {
                    if left.is_some() {
                        return Err(Error::duplicate_field("left"));
                    }
                    left = Some(visitor.visit_value_seed(OptionSeed(NodeSeed { map: self.map }))?);
                }
                MarkedField::Right => {
                    if right.is_some() {
                        return Err(Error::duplicate_field("right"));
                    }
                    right = Some(visitor.visit_value_seed(OptionSeed(NodeSeed { map: self.map }))?);
                }
            }
        }

        let id = id.ok_or_else(|| Error::missing_field("id"))?;
        let data = data.ok_or_else(|| Error::missing_field("data"))?;
        // The node is registered only after both children have been read.
        let node = Rc::new(Node {
            data: data,
            left: left.unwrap_or(None),
            right: right.unwrap_or(None),
        });
        self.map.insert(id, node.clone());
        Ok(node)
    }
}
//////////////////////////////////////////////////////////////////////////////
/// Maybe this should be provided by Serde. Just turns any seed into an
/// optional one.
struct OptionSeed<S>(S);
impl<S> Visitor for OptionSeed<S>
where S: DeserializeSeed
{
type Value = Option<S::Value>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("option")
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where E: Error
{
Ok(None)
}
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where D: Deserializer
{
self.0.deserialize(deserializer).map(Some)
}
}
impl<S> DeserializeSeed for OptionSeed<S>
where S: DeserializeSeed
{
type Value = Option<S::Value>;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where D: Deserializer
{
deserializer.deserialize_option(self)
}
}
//////////////////////////////////////////////////////////////////////////////
/// Tag of the enum form in which child nodes are encoded.
enum Variant { Plain, Marked, Reference }
/// Field keys of the root node and of the `Plain` variant.
enum Field { Data, Left, Right }
/// Field keys of the `Marked` variant.
enum MarkedField { Id, Data, Left, Right }
/// Variant names, in the order of their numeric indices (0, 1, 2).
const VARIANTS: &'static [&'static str] = &["Plain", "Marked", "Reference"];
/// Field names of the root node and of the `Plain` variant, in positional order.
const FIELDS: &'static [&'static str] = &["data", "left", "right"];
/// Field names of the `Marked` variant, in positional order.
const MARKED_FIELDS: &'static [&'static str] = &["id", "data", "left", "right"];
impl Deserialize for Variant {
fn deserialize<D>(deserializer: D) -> Result<Variant, D::Error>
where D: Deserializer
{
struct VariantVisitor;
impl Visitor for VariantVisitor {
type Value = Variant;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("`Plain` or `Marked` or `Reference`")
}
fn visit_u32<E>(self, value: u32) -> Result<Self::Value, E>
where E: Error
{
match value {
0 => Ok(Variant::Plain),
1 => Ok(Variant::Marked),
2 => Ok(Variant::Reference),
_ => Err(Error::invalid_value(Unexpected::Unsigned(value as u64), &"0 <= i < 3")),
}
}
fn visit_str<E>(self, value: &str) -> Result<Variant, E>
where E: Error
{
match value {
"Plain" => Ok(Variant::Plain),
"Marked" => Ok(Variant::Marked),
"Reference" => Ok(Variant::Reference),
_ => Err(Error::unknown_field(value, VARIANTS)),
}
}
}
deserializer.deserialize_struct_field(VariantVisitor)
}
}
impl Deserialize for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
where D: Deserializer
{
struct FieldVisitor;
impl Visitor for FieldVisitor {
type Value = Field;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("`data` or `left` or `right`")
}
fn visit_str<E>(self, value: &str) -> Result<Field, E>
where E: Error
{
match value {
"data" => Ok(Field::Data),
"left" => Ok(Field::Left),
"right" => Ok(Field::Right),
_ => Err(Error::unknown_field(value, FIELDS)),
}
}
}
deserializer.deserialize_struct_field(FieldVisitor)
}
}
impl Deserialize for MarkedField {
fn deserialize<D>(deserializer: D) -> Result<MarkedField, D::Error>
where D: Deserializer
{
struct MarkedFieldVisitor;
impl Visitor for MarkedFieldVisitor {
type Value = MarkedField;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("`id` or `data` or `left` or `right`")
}
fn visit_str<E>(self, value: &str) -> Result<MarkedField, E>
where E: Error
{
match value {
"id" => Ok(MarkedField::Id),
"data" => Ok(MarkedField::Data),
"left" => Ok(MarkedField::Left),
"right" => Ok(MarkedField::Right),
_ => Err(Error::unknown_field(value, MARKED_FIELDS)),
}
}
}
deserializer.deserialize_struct_field(MarkedFieldVisitor)
}
}
@matthewhammer
Copy link

Just found this code, and have started to read it. It looks interesting, and worth a closer study.

Here's my version of a solution to a similar (or the same?) "problem" with serde and Rc: https://github.com/Adapton/hashcons.rust

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment