Skip to content

Instantly share code, notes, and snippets.

@ctron
Created May 5, 2021 09:17
Show Gist options
  • Save ctron/3fa7f9912da044bd1e15331a3676cfd6 to your computer and use it in GitHub Desktop.
Save ctron/3fa7f9912da044bd1e15331a3676cfd6 to your computer and use it in GitHub Desktop.
A streaming JSON serializer
// Copyright 2021 Red Hat Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use actix_http::{http::StatusCode, Response};
use bytes::{BufMut, Bytes, BytesMut};
use core::fmt::Debug;
use futures::{
task::{Context, Poll},
{ready, Stream},
};
use pin_project::pin_project;
use serde::Serialize;
use std::{
fmt::{Display, Formatter},
pin::Pin,
};
/// The internal state of the stream.
///
/// Tracks how much of the enclosing JSON array has been emitted so far, which
/// decides whether the next chunk has to be prefixed with `[` or `,`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    /// Before the first item: the opening bracket has not been written yet.
    Start,
    /// In the middle of processing: at least one item has been written.
    Data,
    /// After the last item: the closing bracket has been written.
    End,
}
/// Failure reported by the array streamer.
///
/// `E` is the error type of the wrapped source stream.
#[derive(Debug)]
pub enum ArrayStreamerError<E: Debug + Display> {
    /// An error value of the source's own error type `E`.
    Source(E),
    /// An error raised by `serde_json`.
    Serializer(serde_json::Error),
}
impl<E> Display for ArrayStreamerError<E>
where
E: Debug + Display,
{
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
match self {
Self::Source(err) => write!(f, "Source error: {}", err),
Self::Serializer(err) => write!(f, "Serializer error: {}", err),
}
}
}
/// Maps streamer errors to HTTP responses.
///
/// Source errors delegate to the source error's own `ResponseError`
/// implementation; serializer errors are reported as internal server errors.
impl<E> actix_http::ResponseError for ArrayStreamerError<E>
where
    E: Debug + Display + actix_http::ResponseError,
{
    /// Status code of the wrapped source error, or `500` for a serializer error.
    fn status_code(&self) -> StatusCode {
        if let Self::Source(source) = self {
            source.status_code()
        } else {
            StatusCode::INTERNAL_SERVER_ERROR
        }
    }

    /// Response of the wrapped source error, or a `500` response carrying the
    /// serializer error message as its body.
    fn error_response(&self) -> Response {
        match self {
            Self::Serializer(json) => {
                let message = json.to_string();
                Response::InternalServerError().body(message)
            }
            Self::Source(source) => source.error_response(),
        }
    }
}
/// Adapter wrapping a stream of `Result<T, E>` items.
///
/// Its `Stream` implementation yields `Bytes` chunks that together form one
/// JSON array, one chunk per serialized item.
///
/// `T` and `E` appear only in the bounds; they tie the adapter to the item and
/// error type of the wrapped stream `S`.
#[pin_project]
pub struct ArrayStreamer<S, T, E>
where
S: Stream<Item = Result<T, E>>,
T: Serialize,
E: Debug + Display,
{
/// The wrapped source stream; structurally pinned so it can be polled in place.
#[pin]
stream: S,
/// How much of the surrounding array has been emitted so far.
state: State,
}
impl<S, T, E> ArrayStreamer<S, T, E>
where
S: Stream<Item = Result<T, E>>,
T: Serialize,
E: Debug + Display,
{
pub fn new(stream: S) -> Self {
Self {
stream,
state: State::Start,
}
}
}
/// Emits the items of the wrapped stream as the chunks of a single JSON array.
///
/// Every item of the source becomes exactly one output chunk:
///
/// * the first item is emitted as `[<json>`,
/// * every following item as `,<json>`,
/// * the end of the source as `]` (or `[]` when the source was empty).
///
/// Errors of the source, and serialization errors, are passed on as `Err`
/// items without changing the internal state.
impl<S, T, E> Stream for ArrayStreamer<S, T, E>
where
    S: Stream<Item = Result<T, E>>,
    T: Serialize,
    E: Debug + Display,
{
    type Item = Result<Bytes, ArrayStreamerError<E>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The closing bracket was already sent: never poll the source again.
        if matches!(self.state, State::End) {
            return Poll::Ready(None);
        }

        let mut this = self.project();

        // Poll first, so that no buffer is built just to be dropped on `Pending`.
        let next = ready!(this.stream.as_mut().poll_next(cx));

        let chunk = match next {
            Some(Err(err)) => return Poll::Ready(Some(Err(ArrayStreamerError::Source(err)))),
            Some(Ok(item)) => {
                let encoded = match serde_json::to_vec(&item) {
                    Ok(encoded) => encoded,
                    Err(err) => {
                        return Poll::Ready(Some(Err(ArrayStreamerError::Serializer(err))))
                    }
                };

                // One prefix byte plus the payload.
                let mut chunk = BytesMut::with_capacity(encoded.len() + 1);
                // `End` was ruled out above, so the state is either `Start`
                // (open the array) or `Data` (separate from the previous item).
                chunk.put_u8(if matches!(*this.state, State::Start) {
                    b'['
                } else {
                    b','
                });
                chunk.extend_from_slice(&encoded);

                // Change state only after the item was encoded successfully.
                *this.state = State::Data;
                chunk
            }
            None => {
                // No more content: close the array, opening it first if the
                // source never produced an item.
                let mut chunk = BytesMut::with_capacity(2);
                if matches!(*this.state, State::Start) {
                    chunk.put_u8(b'[');
                }
                chunk.put_u8(b']');
                *this.state = State::End;
                chunk
            }
        };

        // Every branch above wrote at least one byte, so the chunk is never empty.
        Poll::Ready(Some(Ok(chunk.freeze())))
    }

    /// Size hint of the wrapped stream, plus one for the closing chunk.
    ///
    /// Each source item maps to one output chunk and the end of the source
    /// produces one more (`]` or `[]`), so both bounds are shifted by one.
    /// Once the closing chunk has been emitted nothing is left.
    fn size_hint(&self) -> (usize, Option<usize>) {
        if matches!(self.state, State::End) {
            return (0, Some(0));
        }

        let (lower, upper) = self.stream.size_hint();
        (
            lower.saturating_add(1),
            // An overflowing upper bound degrades to "unknown".
            upper.and_then(|upper| upper.checked_add(1)),
        )
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use futures::{stream, TryStreamExt};

    /// Runs `items` through an [`ArrayStreamer`] and returns all emitted chunks
    /// joined into one string.
    ///
    /// Panics when the streamer reports an error or emits invalid UTF-8, so
    /// that neither can be mistaken for an empty result.
    async fn render<T>(items: Vec<Result<T, String>>) -> String
    where
        T: Serialize,
    {
        let chunks: Vec<Bytes> = ArrayStreamer::new(stream::iter(items))
            .try_collect()
            .await
            .expect("streamer must not fail for error-free input");

        let bytes: Vec<u8> = chunks
            .iter()
            .flat_map(|chunk| chunk.iter().copied())
            .collect();

        String::from_utf8(bytes).expect("streamer must emit valid UTF-8")
    }

    #[tokio::test]
    async fn test_streamer_default() {
        let data: Vec<Result<_, String>> = vec![Ok("foo"), Ok("bar")];
        assert_eq!(render(data).await, r#"["foo","bar"]"#);
    }

    #[tokio::test]
    async fn test_streamer_empty() {
        let data: Vec<Result<String, String>> = vec![];
        assert_eq!(render(data).await, r#"[]"#);
    }

    #[tokio::test]
    async fn test_streamer_source_error() {
        let data: Vec<Result<&str, String>> = vec![Ok("foo"), Err("boom".to_string())];
        let outcome: Result<Vec<Bytes>, _> =
            ArrayStreamer::new(stream::iter(data)).try_collect().await;
        // The source error must surface unchanged, wrapped in `Source`.
        assert!(matches!(
            outcome,
            Err(ArrayStreamerError::Source(ref message)) if message == "boom"
        ));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment