Skip to content

Instantly share code, notes, and snippets.

@rkuhn
Created May 22, 2022 10:12
Show Gist options
  • Save rkuhn/413aa0cb4f7415bbb10c3cddd1fa0615 to your computer and use it in GitHub Desktop.
A streaming-response protocol for rust-libp2p
use super::{
handler::{self, IntoHandler, Request, Response},
RequestReceived, StreamingResponseConfig,
};
use crate::Codec;
use futures::channel::mpsc;
use libp2p::{
core::connection::ConnectionId,
swarm::{NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters},
PeerId,
};
use std::{
collections::VecDeque,
marker::PhantomData,
task::{Context, Poll},
};
/// `NetworkBehaviour` implementing a request/streaming-response protocol described by codec `T`.
///
/// Outgoing requests are queued with [`StreamingResponse::request`]; requests received from
/// remote peers are emitted as [`RequestReceived`] events.
pub struct StreamingResponse<T: Codec + Send + 'static> {
    config: StreamingResponseConfig,
    // inbound requests waiting to be handed to the application
    events: VecDeque<RequestReceived<T>>,
    // outbound requests waiting to be routed to a connection handler
    requests: VecDeque<NetworkBehaviourAction<RequestReceived<T>, IntoHandler<T>>>,
    _ph: PhantomData<T>,
}
impl<T: Codec + Send + 'static> StreamingResponse<T> {
    /// Create the behaviour with empty queues and the given configuration.
    pub fn new(config: StreamingResponseConfig) -> Self {
        let events = VecDeque::new();
        let requests = VecDeque::new();
        Self {
            config,
            events,
            requests,
            _ph: PhantomData,
        }
    }
    /// Queue `request` for `peer_id`; it is handed to any one of that peer's connection
    /// handlers (`NotifyHandler::Any`), and all responses are delivered to `channel`.
    pub fn request(&mut self, peer_id: PeerId, request: T::Request, channel: mpsc::Sender<Response<T::Response>>) {
        let event = Request::new(request, channel);
        let action = NetworkBehaviourAction::NotifyHandler {
            peer_id,
            handler: NotifyHandler::Any,
            event,
        };
        self.requests.push_back(action);
    }
}
impl<T: Codec + Send + 'static> NetworkBehaviour for StreamingResponse<T> {
    type ProtocolsHandler = IntoHandler<T>;
    type OutEvent = RequestReceived<T>;
    /// Build the handler prototype for a new connection from the stored configuration.
    fn new_handler(&mut self) -> Self::ProtocolsHandler {
        let cfg = &self.config;
        IntoHandler::new(
            cfg.spawner.clone(),
            cfg.max_message_size,
            cfg.request_timeout,
            cfg.response_send_buffer_size,
            cfg.keep_alive,
        )
    }
    /// A handler reported an inbound request: attach peer and connection and queue it for
    /// the application.
    fn inject_event(
        &mut self,
        peer_id: PeerId,
        connection: ConnectionId,
        event: <<Self::ProtocolsHandler as libp2p::swarm::IntoProtocolsHandler>::Handler as libp2p::swarm::ProtocolsHandler>::OutEvent,
    ) {
        log::trace!("request received by behaviour: {:?}", event.request);
        let received = RequestReceived {
            peer_id,
            connection,
            request: event.request,
            channel: event.channel,
        };
        self.events.push_back(received);
    }
    /// Emit one queued action per call; outbound requests take priority over delivering
    /// received requests.
    fn poll(
        &mut self,
        _cx: &mut Context<'_>,
        _params: &mut impl PollParameters,
    ) -> Poll<NetworkBehaviourAction<Self::OutEvent, Self::ProtocolsHandler>> {
        if let Some(action) = self.requests.pop_front() {
            log::trace!("triggering request action");
            return Poll::Ready(action);
        }
        self.events
            .pop_front()
            .map_or(Poll::Pending, |ev| Poll::Ready(NetworkBehaviourAction::GenerateEvent(ev)))
    }
}
use super::{
protocol::{self, Requester, Responder},
ProtocolError,
};
use crate::Codec;
use futures::{
channel::mpsc, future::BoxFuture, stream::FuturesUnordered, AsyncWriteExt, FutureExt, SinkExt, StreamExt,
};
use libp2p::{
core::{ConnectedPoint, UpgradeError},
swarm::{
protocols_handler::{InboundUpgradeSend, OutboundUpgradeSend},
IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr,
SubstreamProtocol,
},
PeerId,
};
use std::{
any::Any,
collections::VecDeque,
fmt::Debug,
io::ErrorKind,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
/// One item delivered on a requester's response channel.
#[derive(Debug, PartialEq)]
pub enum Response<T> {
    /// A response message sent by the remote peer.
    Msg(T),
    /// The exchange failed; the handler sends nothing after this.
    Error(ProtocolError),
    /// The remote peer signalled the regular end of the response stream.
    Finished,
}
impl<T> Response<T> {
    /// Extract the message payload, treating every other item as an error.
    ///
    /// `Finished` becomes an `UnexpectedEof` I/O error because a message was expected.
    pub fn into_msg(self) -> Result<T, ProtocolError> {
        match self {
            Self::Msg(payload) => Ok(payload),
            Self::Error(err) => Err(err),
            Self::Finished => {
                let eof: std::io::Error = ErrorKind::UnexpectedEof.into();
                Err(ProtocolError::Io(eof))
            }
        }
    }
}
/// Command from the behaviour to a connection handler: send `request` to the peer and
/// deliver its responses to `channel`.
pub struct Request<T: Codec> {
    request: T::Request,
    channel: mpsc::Sender<Response<T::Response>>,
}
impl<T: Codec> Request<T> {
    /// Bundle a request with the channel that receives its responses.
    pub fn new(request: T::Request, channel: mpsc::Sender<Response<T::Response>>) -> Self {
        Request { channel, request }
    }
}
impl<T: Codec> Debug for Request<T> {
    /// Only the request is shown; the channel is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Request");
        out.field("request", &self.request);
        out.finish()
    }
}
/// Handler output event: an inbound request plus the sender for its response stream.
pub struct RequestReceived<T: Codec> {
    pub(crate) request: T::Request,
    pub(crate) channel: mpsc::Sender<T::Response>,
}
impl<T: Codec> Debug for RequestReceived<T> {
    /// Only the request is shown; the channel is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RequestReceived");
        out.field("request", &self.request);
        out.finish()
    }
}
/// Connection-independent handler prototype; converted into a [`Handler`] once a connection
/// exists. Holds the settings copied from the behaviour's configuration.
pub struct IntoHandler<T> {
    spawner: Spawner,
    max_message_size: u32,
    request_timeout: Duration,
    response_send_buffer_size: usize,
    keep_alive: bool,
    _ph: PhantomData<T>,
}
impl<T> IntoHandler<T> {
    /// Store the handler settings for later use in `into_handler`.
    pub fn new(
        spawner: Spawner,
        max_message_size: u32,
        request_timeout: Duration,
        response_send_buffer_size: usize,
        keep_alive: bool,
    ) -> Self {
        IntoHandler {
            _ph: PhantomData,
            keep_alive,
            response_send_buffer_size,
            request_timeout,
            max_message_size,
            spawner,
        }
    }
}
impl<T: Codec + Send + 'static> IntoProtocolsHandler for IntoHandler<T> {
    type Handler = Handler<T>;
    /// Build the per-connection handler; the remote peer and endpoint are not needed.
    fn into_handler(self, _remote_peer_id: &PeerId, _connected_point: &ConnectedPoint) -> Self::Handler {
        let Self {
            spawner,
            max_message_size,
            request_timeout,
            response_send_buffer_size,
            keep_alive,
            ..
        } = self;
        Handler::new(spawner, max_message_size, request_timeout, response_send_buffer_size, keep_alive)
    }
    /// Inbound protocol advertised before the handler exists; same limit as the handler uses.
    fn inbound_protocol(&self) -> <Self::Handler as ProtocolsHandler>::InboundProtocol {
        let limit = self.max_message_size;
        Responder::new(limit)
    }
}
/// A substream task. Its output is type-erased so that a spawner can wrap the task (e.g.
/// in a runtime join handle); the handler expects to find a boxed `Result<(), ProtocolError>`.
pub type ResponseFuture = BoxFuture<'static, Box<dyn Any + Send + 'static>>;
/// Turns a substream task into the future the handler polls.
pub type Spawner = Arc<
    dyn Fn(ResponseFuture) -> ResponseFuture + Send + Sync + 'static,
>;
/// Events produced by `Handler::poll`. Outbound substream requests carry the requester's
/// response channel as open-info; custom events are inbound requests.
type ProtocolEvent<T> = ProtocolsHandlerEvent<
    Requester<T>,
    mpsc::Sender<Response<<T as Codec>::Response>>,
    RequestReceived<T>,
    ProtocolError,
>;
/// Per-connection protocol handler.
///
/// `events` holds events waiting to be returned from `poll`; `streams` holds the running
/// substream tasks (as returned by the spawner) until they complete.
pub struct Handler<T: Codec> {
    events: VecDeque<ProtocolEvent<T>>,
    streams: FuturesUnordered<ResponseFuture>,
    spawner: Spawner,
    max_message_size: u32,
    request_timeout: Duration,
    response_send_buffer_size: usize,
    keep_alive: bool,
}
impl<T: Codec> Debug for Handler<T> {
    /// Shows only the number of queued events and running substream tasks.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let queued_events = self.events.len();
        let running_streams = self.streams.len();
        f.debug_struct("Handler")
            .field("events", &queued_events)
            .field("streams", &running_streams)
            .finish()
    }
}
impl<T: Codec> Handler<T> {
    /// Create a handler with no queued events and no running substreams.
    pub fn new(
        spawner: Spawner,
        max_message_size: u32,
        request_timeout: Duration,
        response_send_buffer_size: usize,
        keep_alive: bool,
    ) -> Self {
        Self {
            spawner,
            max_message_size,
            request_timeout,
            response_send_buffer_size,
            keep_alive,
            events: VecDeque::new(),
            streams: FuturesUnordered::new(),
        }
    }
}
impl<T: Codec + Send + 'static> ProtocolsHandler for Handler<T> {
type InEvent = Request<T>;
type OutEvent = RequestReceived<T>;
type Error = ProtocolError;
type InboundProtocol = Responder<T>;
type OutboundProtocol = Requester<T>;
type InboundOpenInfo = ();
type OutboundOpenInfo = mpsc::Sender<Response<T::Response>>;
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
SubstreamProtocol::new(Responder::new(self.max_message_size), ()).with_timeout(self.request_timeout)
}
fn inject_fully_negotiated_inbound(
&mut self,
protocol: <Self::InboundProtocol as InboundUpgradeSend>::Output,
_info: Self::InboundOpenInfo,
) {
let (request, mut stream) = protocol;
let (channel, mut rx) = mpsc::channel(self.response_send_buffer_size);
let max_message_size = self.max_message_size;
log::trace!("handler received request");
let task = (self.spawner)(
async move {
log::trace!("starting send loop");
let mut buffer = Vec::new();
loop {
// only flush once we’re going to sleep
let response = match rx.try_next() {
Ok(Some(r)) => r,
Ok(None) => break,
Err(_) => {
log::trace!("flushing stream");
stream.flush().await?;
match rx.next().await {
Some(r) => r,
None => break,
}
}
};
protocol::write_msg(&mut stream, response, max_message_size, &mut buffer).await?;
}
log::trace!("flushing and closing substream");
protocol::write_finish(&mut stream).await?;
Result::<_, ProtocolError>::Ok(())
}
.map(|res| -> Box<dyn Any + Send + 'static> { Box::new(res) })
.boxed(),
);
self.streams.push(task);
self.events
.push_back(ProtocolsHandlerEvent::Custom(RequestReceived { request, channel }));
}
fn inject_fully_negotiated_outbound(
&mut self,
mut stream: <Self::OutboundProtocol as OutboundUpgradeSend>::Output,
mut tx: Self::OutboundOpenInfo,
) {
let max_message_size = self.max_message_size;
let task = (self.spawner)(
async move {
log::trace!("starting receive loop");
let mut buffer = Vec::new();
loop {
match protocol::read_msg(&mut stream, max_message_size, &mut buffer)
.await
.unwrap_or_else(Response::Error)
{
Response::Msg(msg) => {
tx.feed(Response::Msg(msg)).await?;
log::trace!("response sent to client code");
}
Response::Error(e) => {
log::debug!("sending substream error {}", e);
tx.feed(Response::Error(e)).await?;
return Result::<_, ProtocolError>::Ok(());
}
Response::Finished => {
log::trace!("finishing substream");
tx.feed(Response::Finished).await?;
return Ok(());
}
}
}
}
.map(|res| -> Box<dyn Any + Send + 'static> { Box::new(res) })
.boxed(),
);
self.streams.push(task);
}
fn inject_event(&mut self, command: Self::InEvent) {
let Request { request, channel } = command;
log::trace!("requesting {:?}", request);
self.events.push_back(ProtocolsHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(Requester::new(self.max_message_size, request), channel)
.with_timeout(self.request_timeout),
})
}
fn inject_dial_upgrade_error(
&mut self,
mut tx: Self::OutboundOpenInfo,
error: ProtocolsHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>,
) {
let error = match error {
ProtocolsHandlerUpgrErr::Timeout => ProtocolError::Timeout,
ProtocolsHandlerUpgrErr::Timer => ProtocolError::Timeout,
ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)) => e,
ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)) => e.into(),
};
log::debug!("dial upgrade error: {}", error);
if let Err(Response::Error(e)) = tx.try_send(Response::Error(error)).map_err(|e| e.into_inner()) {
log::warn!("cannot send upgrade error to requester: {}", e);
}
}
fn connection_keep_alive(&self) -> KeepAlive {
if !self.keep_alive && self.streams.is_empty() {
KeepAlive::No
} else {
KeepAlive::Yes
}
}
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ProtocolEvent<T>> {
loop {
if self.streams.is_empty() {
break;
}
if let Poll::Ready(result) = self.streams.poll_next_unpin(cx) {
// since the set was not empty, this must be a Some()
if let Some(Err(e)) = result.and_then(|e| e.downcast::<Result<(), ProtocolError>>().ok().map(|b| *b)) {
// no need to tear down the connection, substream is already closed
log::warn!("error in substream task: {}", e);
}
} else {
break;
}
}
match self.events.pop_front() {
Some(e) => Poll::Ready(e),
None => Poll::Pending,
}
}
}
mod behaviour;
mod handler;
mod protocol;
#[cfg(test)]
mod tests;
pub use behaviour::StreamingResponse;
pub use handler::{Response, ResponseFuture, Spawner};
pub use protocol::ProtocolError;
use crate::Codec;
use futures::channel::mpsc;
use libp2p::{core::connection::ConnectionId, PeerId};
use std::{fmt::Debug, sync::Arc, time::Duration};
/// An inbound request, emitted by [`StreamingResponse`] as its `OutEvent`.
///
/// Send responses through `channel`. The response stream is finished once every sender
/// for it has been dropped or closed.
pub struct RequestReceived<T: Codec> {
    pub peer_id: PeerId,
    pub connection: ConnectionId,
    pub request: T::Request,
    pub channel: mpsc::Sender<T::Response>,
}
impl<T: Codec> Debug for RequestReceived<T> {
    /// Shows peer, connection and request; the channel is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RequestReceived");
        out.field("peer_id", &self.peer_id);
        out.field("connection", &self.connection);
        out.field("request", &self.request);
        out.finish()
    }
}
/// Configuration for [`StreamingResponse`]: start from `Default` and adjust with the
/// `with_*` builder methods.
pub struct StreamingResponseConfig {
    /// Turns each substream task into the future the handler polls (see `with_spawner`).
    pub spawner: Spawner,
    /// Timeout for substream upgrades (see `with_request_timeout`).
    pub request_timeout: Duration,
    /// Upper bound in bytes for one serialised request or response.
    pub max_message_size: u32,
    /// Capacity of the response channel created for each incoming request.
    pub response_send_buffer_size: usize,
    /// Whether to keep connections alive while no request is active.
    pub keep_alive: bool,
}
impl StreamingResponseConfig {
    /// Spawn response stream handling tasks using the given function.
    ///
    /// The function may be called from any context. Even when using Tokio, you cannot assume
    /// it runs on a Tokio thread, so point it at the target thread pool directly, e.g. via a
    /// runtime handle.
    ///
    /// Without this method, tasks are polled via the Swarm, which may become an I/O
    /// bottleneck.
    pub fn with_spawner(mut self, spawner: impl Fn(ResponseFuture) -> ResponseFuture + Send + Sync + 'static) -> Self {
        self.spawner = Arc::new(spawner);
        self
    }
    /// Timeout for the transmission of the request to the peer; default is 10 seconds.
    pub fn with_request_timeout(mut self, request_timeout: Duration) -> Self {
        self.request_timeout = request_timeout;
        self
    }
    /// Maximum message size in bytes permitted for requests and responses.
    ///
    /// Values up to `u32::MAX` (just under 4 GiB) are possible; the default is 1 MB
    /// (1_000_000 bytes). Huge messages need correspondingly large buffers and may not be
    /// desirable.
    pub fn with_max_message_size(mut self, max_message_size: u32) -> Self {
        self.max_message_size = max_message_size;
        self
    }
    /// Queue size, in messages, of the response channel created for each incoming request.
    ///
    /// All channels are bounded and use back-pressure. This size allows some decoupling
    /// between response generation and network transmission. Default is 128.
    pub fn with_response_send_buffer_size(mut self, response_send_buffer_size: usize) -> Self {
        self.response_send_buffer_size = response_send_buffer_size;
        self
    }
    /// If set to true, this behaviour keeps the connection alive.
    ///
    /// Otherwise the connection is released (i.e. closed, unless another behaviour keeps it
    /// alive) when no requests are active. Default is `false`.
    pub fn with_keep_alive(mut self, keep_alive: bool) -> Self {
        self.keep_alive = keep_alive;
        self
    }
}
impl Default for StreamingResponseConfig {
    /// Tasks are polled by the Swarm itself (identity spawner), 10 s request timeout,
    /// 1 MB message limit, 128-message response buffer, no keep-alive.
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(10),
            max_message_size: 1_000_000,
            response_send_buffer_size: 128,
            keep_alive: false,
            spawner: Arc::new(|f| f),
        }
    }
}
use super::handler::Response;
use crate::Codec;
use derive_more::{Display, Error, From};
use futures::{channel::mpsc, future::BoxFuture, AsyncReadExt, AsyncWriteExt, FutureExt};
use libp2p::{
core::{upgrade::NegotiationError, UpgradeInfo},
swarm::NegotiatedSubstream,
InboundUpgrade, OutboundUpgrade,
};
use serde::de::DeserializeOwned;
use std::{
fmt::{Display, Write},
io::ErrorKind,
iter::{once, Once},
marker::PhantomData,
};
/// Errors of the streaming-response protocol.
///
/// Every variant has a one-byte wire code (`as_code` / `from_code`), so a failure on one side
/// can be reported to the other. Variant payloads are not part of that encoding.
#[derive(Error, Display, Debug, From)]
pub enum ProtocolError {
    /// A timeout elapsed, e.g. while upgrading a substream.
    #[display(fmt = "timeout while waiting for request receive confirmation")]
    Timeout,
    /// A received frame announced this many bytes, more than the configured maximum.
    #[display(fmt = "message too large received: {}", _0)]
    #[from(ignore)]
    MessageTooLargeRecv(#[error(ignore)] usize),
    /// A message to be sent serialised to this many bytes, more than the configured maximum.
    #[display(fmt = "message too large sent: {}", _0)]
    #[from(ignore)]
    MessageTooLargeSent(#[error(ignore)] usize),
    /// Substream protocol negotiation failed.
    #[display(fmt = "substream protocol negotiation error: {}", _0)]
    Negotiation(NegotiationError),
    /// An I/O error; also used for an unexpected end of stream and for unknown error codes
    /// received from the peer.
    #[display(fmt = "I/O error: {}", _0)]
    Io(std::io::Error),
    /// CBOR serialisation or deserialisation failed.
    #[display(fmt = "(de)serialisation error: {}", _0)]
    Serde(serde_cbor::Error),
    /// Sending on an internal channel failed.
    #[display(fmt = "internal channel error")]
    Channel(mpsc::SendError),
    /// A spawned task did not complete; the flag tells whether it was cancelled.
    ///
    /// This variant is useful for implementing the function to pass to
    /// [`StreamingResponseConfig::with_spawner`](crate::v2::StreamingResponseConfig::with_spawner).
    #[display(fmt = "spawned task failed (cancelled={})", _0)]
    JoinError(#[error(ignore)] bool),
}
impl PartialEq for ProtocolError {
    /// Equal when the variants match and so do their payloads. Negotiation, I/O and serde
    /// errors are compared via their `Display` text.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::MessageTooLargeRecv(a), Self::MessageTooLargeRecv(b))
            | (Self::MessageTooLargeSent(a), Self::MessageTooLargeSent(b)) => a == b,
            (Self::JoinError(a), Self::JoinError(b)) => a == b,
            (Self::Channel(a), Self::Channel(b)) => a == b,
            (Self::Negotiation(a), Self::Negotiation(b)) => a.to_string() == b.to_string(),
            (Self::Io(a), Self::Io(b)) => a.to_string() == b.to_string(),
            (Self::Serde(a), Self::Serde(b)) => a.to_string() == b.to_string(),
            // payload-free variants compare equal; differing variants compare unequal
            (lhs, rhs) => core::mem::discriminant(lhs) == core::mem::discriminant(rhs),
        }
    }
}
impl ProtocolError {
    /// Wire code for this error, written after a `255` marker byte to tell the peer why the
    /// stream was aborted. Code `0` is never returned: it marks a regular end of stream.
    pub fn as_code(&self) -> u8 {
        match self {
            Self::Timeout => 1,
            Self::MessageTooLargeRecv(_) => 2,
            Self::MessageTooLargeSent(_) => 3,
            Self::Negotiation(_) => 4,
            Self::Io(_) => 5,
            Self::Serde(_) => 6,
            Self::Channel(_) => 7,
            Self::JoinError(_) => 8,
        }
    }
    /// Rebuild an error from a wire code received from the peer.
    ///
    /// Only the kind survives: sizes become `0`, I/O and serde errors carry a generic text,
    /// and unknown codes map to an `ErrorKind::Other` I/O error naming the code.
    pub fn from_code(code: u8) -> Self {
        match code {
            1 => Self::Timeout,
            2 => Self::MessageTooLargeRecv(0),
            3 => Self::MessageTooLargeSent(0),
            4 => Self::Negotiation(NegotiationError::Failed),
            5 => Self::Io(std::io::Error::new(ErrorKind::Other, "some error on peer")),
            6 => {
                let cause = std::io::Error::new(ErrorKind::Other, "serde error on peer");
                Self::Serde(cause.into())
            }
            7 => {
                // Fabricate a `SendError`: the receiver (bound to `_`) is dropped right away,
                // so the send fails with a disconnect.
                let (mut tx, _) = mpsc::channel(1);
                let send_error = tx.try_send(0).unwrap_err().into_send_error();
                Self::Channel(send_error)
            }
            8 => Self::JoinError(false),
            unknown => Self::Io(std::io::Error::new(
                ErrorKind::Other,
                format!("unknown error code {}", unknown),
            )),
        }
    }
}
/// Serialise `msg` as CBOR and write it as one frame: a 4-byte big-endian length followed by
/// the payload.
///
/// `buffer` is scratch space reused between calls. If serialisation fails or the payload
/// exceeds `max_size`, the error trailer is written (which also closes the substream) and
/// the error is returned.
pub async fn write_msg(
    io: &mut NegotiatedSubstream,
    msg: impl serde::Serialize,
    max_size: u32,
    buffer: &mut Vec<u8>,
) -> Result<(), ProtocolError> {
    // placeholder for the length prefix, filled in once the payload size is known
    buffer.resize(4, 0);
    let size = match serde_cbor::to_writer(&mut *buffer, &msg) {
        Ok(()) => buffer.len() - 4,
        Err(e) => {
            let err = ProtocolError::Serde(e);
            write_err(io, &err).await?;
            return Err(err);
        }
    };
    if size > (max_size as usize) {
        log::debug!("message size {} too large (max = {})", size, max_size);
        let err = ProtocolError::MessageTooLargeSent(size);
        write_err(io, &err).await?;
        return Err(err);
    }
    log::trace!("sending message of size {}", size);
    // cannot truncate: size <= max_size, which is a u32
    let prefix = (size as u32).to_be_bytes();
    buffer[..4].copy_from_slice(&prefix);
    io.write_all(&buffer[..]).await?;
    Ok(())
}
pub async fn write_err(io: &mut NegotiatedSubstream, err: &ProtocolError) -> Result<(), std::io::Error> {
let buf = [255, err.as_code()];
io.write_all(&buf).await?;
io.flush().await?;
io.close().await?;
Ok(())
}
pub async fn write_finish(io: &mut NegotiatedSubstream) -> Result<(), std::io::Error> {
let buf = [255, 0];
io.write_all(&buf).await?;
io.flush().await?;
io.close().await?;
Ok(())
}
/// Read one frame, or the end-of-stream trailer, from `io`.
///
/// A frame is a 4-byte big-endian length followed by a CBOR payload of that length. If the
/// stream ends after exactly the two bytes `[255, code]`, this is the trailer written by
/// `write_finish` (`code == 0`, yields `Response::Finished`) or by `write_err` (any other
/// code, decoded with `ProtocolError::from_code`).
///
/// `buffer` is scratch space and is resized to the payload length.
///
/// # Errors
/// `UnexpectedEof` if the stream ends anywhere else. `MessageTooLargeRecv` if the announced
/// length exceeds `max_size`; in that case nothing beyond the header is read. Otherwise, I/O
/// or deserialisation errors.
pub async fn read_msg<T: DeserializeOwned>(
    io: &mut NegotiatedSubstream,
    max_size: u32,
    buffer: &mut Vec<u8>,
) -> Result<Response<T>, ProtocolError> {
    let mut size_bytes = [0u8; 4];
    let mut filled = 0;
    while filled < size_bytes.len() {
        let read = io.read(&mut size_bytes[filled..]).await?;
        log::trace!("read {} header bytes", read);
        if read == 0 {
            // EOF before a full header: only a complete two-byte trailer is acceptable
            return match &size_bytes[..filled] {
                [255, 0] => Ok(Response::Finished),
                [255, code] => Err(ProtocolError::from_code(*code)),
                _ => Err(ProtocolError::Io(ErrorKind::UnexpectedEof.into())),
            };
        }
        filled += read;
    }
    let size = u32::from_be_bytes(size_bytes);
    if size > max_size {
        // Reject based on the header alone. Reading more just for diagnostics could block
        // until the peer sends or closes, and an I/O error there would mask this error.
        log::debug!("message size {} too large (max = {})", size, max_size);
        log::debug!("header bytes {:?}", size_bytes);
        return Err(ProtocolError::MessageTooLargeRecv(size as usize));
    }
    log::trace!("received header: msg is {} bytes", size);
    buffer.resize(size as usize, 0);
    io.read_exact(buffer.as_mut_slice()).await?;
    log::trace!("all bytes read");
    Ok(Response::Msg(serde_cbor::from_slice(buffer.as_slice())?))
}
/// Inbound upgrade: reads the request from a substream opened by the remote peer.
#[derive(Debug)]
pub struct Responder<T> {
    max_message_size: u32,
    _ph: PhantomData<T>,
}
impl<T> Responder<T> {
    /// `max_message_size` bounds the size of the request this upgrade accepts.
    pub fn new(max_message_size: u32) -> Self {
        Responder {
            _ph: PhantomData,
            max_message_size,
        }
    }
}
impl<T: Codec> UpgradeInfo for Responder<T> {
    type Info = &'static [u8];
    type InfoIter = Once<&'static [u8]>;
    /// Advertise the single protocol name provided by the codec.
    fn protocol_info(&self) -> Self::InfoIter {
        let name = T::protocol_info();
        once(name)
    }
}
/// Renders a protocol name (raw bytes) for log messages.
///
/// Printable ASCII (space through `~`) is written as-is. Every other byte, including control
/// characters such as DEL (0x7f) and non-ASCII bytes, is shown as U+FFFD.
struct ProtoNameDisplay(&'static [u8]);
impl Display for ProtoNameDisplay {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for &byte in self.0 {
            // printable ASCII is 0x20..=0x7e; 0x7f (DEL) is a control character
            let shown = if byte == b' ' || byte.is_ascii_graphic() {
                char::from(byte)
            } else {
                '\u{fffd}'
            };
            f.write_char(shown)?;
        }
        Ok(())
    }
}
impl<T: Codec> InboundUpgrade<NegotiatedSubstream> for Responder<T> {
    type Output = (T::Request, NegotiatedSubstream);
    type Error = ProtocolError;
    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
    /// Read exactly one request frame. A trailer or error instead of a message fails the
    /// upgrade.
    fn upgrade_inbound(self, mut socket: NegotiatedSubstream, info: Self::Info) -> Self::Future {
        let Self { max_message_size, .. } = self;
        async move {
            log::trace!("starting inbound upgrade `{}`", ProtoNameDisplay(info));
            let mut scratch = Vec::new();
            let frame = read_msg(&mut socket, max_message_size, &mut scratch).await?;
            let msg = frame.into_msg()?;
            log::trace!("request received: {:?}", msg);
            Ok((msg, socket))
        }
        .boxed()
    }
}
/// Outbound upgrade: writes the request to a substream we opened.
#[derive(Debug)]
pub struct Requester<T: Codec> {
    max_message_size: u32,
    request: T::Request,
}
impl<T: Codec> Requester<T> {
    /// `max_message_size` bounds the serialised size of `request`.
    pub fn new(max_message_size: u32, request: T::Request) -> Self {
        Requester {
            request,
            max_message_size,
        }
    }
}
impl<T: Codec> UpgradeInfo for Requester<T> {
    type Info = &'static [u8];
    type InfoIter = Once<&'static [u8]>;
    /// Advertise the single protocol name provided by the codec.
    fn protocol_info(&self) -> Self::InfoIter {
        let name = T::protocol_info();
        once(name)
    }
}
impl<T: Codec> OutboundUpgrade<NegotiatedSubstream> for Requester<T> {
    type Output = NegotiatedSubstream;
    type Error = ProtocolError;
    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
    /// Write and flush the request frame, then hand the substream on for reading responses.
    fn upgrade_outbound(self, mut socket: NegotiatedSubstream, info: Self::Info) -> Self::Future {
        let max_message_size = self.max_message_size;
        let request = self.request;
        async move {
            log::trace!("starting output upgrade `{}`", ProtoNameDisplay(info));
            let mut scratch = Vec::new();
            write_msg(&mut socket, request, max_message_size, &mut scratch).await?;
            socket.flush().await?;
            log::trace!("all bytes sent");
            Ok(socket)
        }
        .boxed()
    }
}
use super::{ProtocolError, StreamingResponse, StreamingResponseConfig};
use crate::{
v2::{handler::Response, RequestReceived},
Codec,
};
use futures::{
channel::mpsc::{self, Receiver, Sender},
Future, FutureExt, SinkExt, StreamExt,
};
use libp2p::{
core::{transport::MemoryTransport, upgrade::Version},
identity::Keypair,
mplex::MplexConfig,
multiaddr::Protocol,
plaintext::PlainText2Config,
swarm::{SwarmBuilder, SwarmEvent},
Multiaddr, PeerId, Swarm, Transport,
};
use tokio::runtime::{Handle, Runtime};
use tracing_subscriber::{fmt::format::FmtSpan, util::SubscriberInitExt, EnvFilter};
mod proto;
/// Protocol name used by the test codec.
const PROTO: &[u8] = "/my/test".as_bytes();
fn test_swarm(use_spawner: Option<Handle>) -> Swarm<StreamingResponse<Proto>> {
let local_key = Keypair::generate_ed25519();
let local_public_key = local_key.public();
let local_peer_id = local_public_key.clone().into();
let transport = MemoryTransport::default()
.upgrade(Version::V1)
.authenticate(PlainText2Config { local_public_key })
.multiplex(MplexConfig::new())
.boxed();
let mut config = StreamingResponseConfig::default()
.with_keep_alive(true)
.with_max_message_size(100);
#[allow(clippy::redundant_closure)]
if let Some(rt) = use_spawner {
config = config.with_spawner(move |f| rt.spawn(f).map(|r| r.unwrap_or_else(|e| Box::new(e))).boxed());
}
let behaviour = StreamingResponse::new(config);
SwarmBuilder::new(transport, behaviour, local_peer_id).build()
}
fn fake_swarm(rt: &Runtime, bytes: &[u8]) -> Swarm<proto::TestBehaviour> {
let local_key = Keypair::generate_ed25519();
let local_public_key = local_key.public();
let local_peer_id = local_public_key.clone().into();
let transport = MemoryTransport::default()
.upgrade(Version::V1)
.authenticate(PlainText2Config { local_public_key })
.multiplex(MplexConfig::new())
.boxed();
let behaviour = proto::TestBehaviour(rt.handle().clone(), bytes.to_owned());
SwarmBuilder::new(transport, behaviour, local_peer_id).build()
}
struct Proto;
impl Codec for Proto {
type Request = String;
type Response = String;
fn protocol_info() -> &'static [u8] {
PROTO
}
}
/// Drive swarm `$s` until an event matches `$p`, logging every event, and evaluate to `$e`.
/// Panics if the swarm's event stream ends first.
macro_rules! wait4 {
    ($s:ident, $p:pat => $e:expr) => {
        loop {
            let ev = match $s.next().await {
                Some(ev) => ev,
                None => panic!("{} STOPPED", stringify!($s)),
            };
            log::info!("{} got {:?}", stringify!($s), ev);
            if let $p = ev {
                break $e;
            }
        }
    };
}
/// Spawn a Tokio task that drives swarm `$s` until its stream ends. Every event is logged,
/// and the expression of the first matching `$p => $e` arm is run.
macro_rules! task {
    ($s:ident $(, $p:pat => $e:expr)*) => {
        tokio::spawn(async move {
            loop {
                let ev = match $s.next().await {
                    Some(ev) => ev,
                    None => break,
                };
                log::info!("{} got {:?}", stringify!($s), ev);
                match ev {
                    $($p => ($e),)*
                    _ => {}
                }
            }
            log::info!("{} STOPPED", stringify!($s));
        })
    };
}
/// Render a value with its `Debug` implementation, for comparing against expected text.
fn dbg<T: std::fmt::Debug>(x: T) -> String {
    format!("{x:?}")
}
/// Install a `tracing` subscriber filtered by the environment (`EnvFilter::from_default_env`),
/// reporting span enter/close events. A failure to install (e.g. because another test
/// already installed one) is ignored.
fn setup_logger() {
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .with_span_events(FmtSpan::ENTER | FmtSpan::CLOSE)
        .finish();
    let _ = subscriber.try_init();
}
#[test]
fn smoke() {
setup_logger();
let rt = Runtime::new().unwrap();
let mut asker = test_swarm(None);
let asker_id = *asker.local_peer_id();
let mut responder = test_swarm(None);
let responder_id = *responder.local_peer_id();
asker.listen_on(Multiaddr::empty().with(Protocol::Memory(0))).unwrap();
rt.block_on(async move {
let addr = wait4!(asker, SwarmEvent::NewListenAddr { address, .. } => address);
responder.dial(addr).unwrap();
task!(responder,
SwarmEvent::Behaviour(RequestReceived { request, peer_id, mut channel, .. }) => {
tokio::spawn(async move {
channel.feed(request).await.unwrap();
channel.feed(peer_id.to_string()).await.unwrap();
channel.close().await.unwrap();
});
}
);
let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id);
assert_eq!(peer_id, responder_id);
let (tx, rx) = mpsc::channel(10);
asker.behaviour_mut().request(peer_id, "request".to_owned(), tx);
task!(asker);
let response = rx
.map(|r| match r {
Response::Msg(m) => Some(m),
Response::Error(e) => panic!("got error: {:#}", e),
Response::Finished => None,
})
.collect::<Vec<_>>()
.await;
assert_eq!(
response,
vec![Some("request".to_owned()), Some(asker_id.to_string()), None]
);
});
}
#[test]
fn smoke_executor() {
setup_logger();
let rt = Runtime::new().unwrap();
let mut asker = test_swarm(Some(rt.handle().clone()));
let asker_id = *asker.local_peer_id();
let mut responder = test_swarm(Some(rt.handle().clone()));
let responder_id = *responder.local_peer_id();
asker.listen_on(Multiaddr::empty().with(Protocol::Memory(0))).unwrap();
rt.block_on(async move {
let addr = wait4!(asker, SwarmEvent::NewListenAddr { address, .. } => address);
responder.dial(addr).unwrap();
task!(responder,
SwarmEvent::Behaviour(RequestReceived { request, peer_id, mut channel, .. }) => {
tokio::spawn(async move {
channel.feed(request).await.unwrap();
channel.feed(peer_id.to_string()).await.unwrap();
channel.close().await.unwrap();
});
}
);
let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id);
assert_eq!(peer_id, responder_id);
let (tx, rx) = mpsc::channel(10);
asker.behaviour_mut().request(peer_id, "request".to_owned(), tx);
task!(asker);
let response = rx
.map(|r| match r {
Response::Msg(m) => Some(m),
Response::Error(e) => panic!("got error: {:#}", e),
Response::Finished => None,
})
.collect::<Vec<_>>()
.await;
assert_eq!(
response,
vec![Some("request".to_owned()), Some(asker_id.to_string()), None]
);
});
}
/// Connect an asker to a real responder, send `request`, and let `logic` answer each inbound
/// request; `f` then inspects the asker's response channel.
fn test_setup<F, Fut, L>(request: String, logic: L, f: F)
where
    F: FnOnce(Receiver<Response<String>>) -> Fut + Send + 'static,
    Fut: Future,
    L: Fn(String, PeerId, Sender<String>) + Send + 'static,
{
    setup_logger();
    let rt = Runtime::new().unwrap();
    let mut asker = test_swarm(None);
    let mut responder = test_swarm(None);
    rt.block_on(async move {
        let listen_addr = Multiaddr::empty().with(Protocol::Memory(0));
        responder.listen_on(listen_addr).unwrap();
        let addr = wait4!(responder, SwarmEvent::NewListenAddr { address, .. } => address);
        task!(responder, SwarmEvent::Behaviour(RequestReceived { request, peer_id, channel, .. }) => logic(request, peer_id, channel));
        asker.dial(addr).unwrap();
        let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id);
        let (response_tx, response_rx) = mpsc::channel(10);
        asker.behaviour_mut().request(peer_id, request, response_tx);
        task!(asker);
        f(response_rx).await;
    });
}
/// Connect an asker to a fake responder configured with `bytes`, send `"request"`, and let
/// `f` inspect the asker's response channel.
fn fake_setup<F, Fut>(bytes: &[u8], f: F)
where
    F: FnOnce(Receiver<Response<String>>) -> Fut + Send + 'static,
    Fut: Future,
{
    setup_logger();
    let rt = Runtime::new().unwrap();
    let mut asker = test_swarm(None);
    let mut responder = fake_swarm(&rt, bytes);
    rt.block_on(async move {
        let listen_addr = Multiaddr::empty().with(Protocol::Memory(0));
        responder.listen_on(listen_addr).unwrap();
        let addr = wait4!(responder, SwarmEvent::NewListenAddr { address, .. } => address);
        task!(responder);
        asker.dial(addr).unwrap();
        let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id);
        let (response_tx, response_rx) = mpsc::channel(10);
        asker.behaviour_mut().request(peer_id, "request".to_owned(), response_tx);
        task!(asker);
        f(response_rx).await;
    });
}
#[test]
fn err_size() {
    // "zzzz" read as a big-endian length is 0x7a7a7a7a, far above the asker's 100-byte limit
    fake_setup(b"zzzz", |mut rx| async move {
        let first = rx.next().await;
        assert_eq!(
            first,
            Some(Response::Error(ProtocolError::MessageTooLargeRecv(0x7a7a_7a7a)))
        );
    });
}
#[test]
fn err_nothing() {
    // presumably the fake responder closes the stream without writing anything
    fake_setup(b"", |mut rx| async move {
        let first = rx.next().await.unwrap();
        assert_eq!(dbg(first), "Error(Io(Kind(UnexpectedEof)))");
    });
}
#[test]
fn err_incomplete() {
    // one complete frame ("abcd" as CBOR text), then a header announcing 16 bytes with only 4 present
    fake_setup(b"\0\0\0\x05dabcd\0\0\0\x10abcd", |mut rx| async move {
        let first = rx.next().await;
        assert_eq!(first, Some(Response::Msg("abcd".to_owned())));
        let second = rx.next().await.unwrap();
        assert_eq!(dbg(second), "Error(Io(Kind(UnexpectedEof)))");
    });
}
#[test]
fn err_no_finish() {
    // one complete frame, then the stream ends without the `[255, 0]` trailer
    fake_setup(b"\0\0\0\x05dabcd", |mut rx| async move {
        let first = rx.next().await;
        assert_eq!(first, Some(Response::Msg("abcd".to_owned())));
        let second = rx.next().await.unwrap();
        assert_eq!(dbg(second), "Error(Io(Kind(UnexpectedEof)))");
    });
}
#[test]
fn err_deser() {
    // the 4-byte payload "abcd" is not exactly one CBOR value
    fake_setup(b"\0\0\0\x04abcd", |mut rx| async move {
        let first = rx.next().await;
        assert_eq!(
            dbg(first),
            "Some(Error(Serde(ErrorImpl { code: TrailingData, offset: 3 })))"
        );
    });
}
#[test]
fn err_response_size() {
    // The echoed response (request + peer id) exceeds 100 bytes on the responder. Only the
    // error code travels back, so the asker sees `MessageTooLargeSent` with a 0 size.
    let request = "123456789012345678901234567890123456789012345678901234567890".to_owned();
    let respond = |mut request: String, peer_id: PeerId, mut channel: Sender<String>| {
        tokio::spawn(async move {
            request.push_str(&*peer_id.to_string());
            channel.feed(request).await.unwrap();
        });
    };
    test_setup(request, respond, |mut rx| async move {
        let first = rx.next().await;
        assert_eq!(first, Some(Response::Error(ProtocolError::MessageTooLargeSent(0))));
    });
}
#[test]
fn err_request_size() {
    // A 100-character string serialises to 102 CBOR bytes (2-byte header), over the limit, so
    // the asker's own outbound upgrade fails with the real size.
    let request =
        "1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890"
            .to_owned();
    let respond = |mut request: String, peer_id: PeerId, mut channel: Sender<String>| {
        tokio::spawn(async move {
            request.push_str(&*peer_id.to_string());
            channel.feed(request).await.unwrap();
        });
    };
    test_setup(request, respond, |mut rx| async move {
        let first = rx.next().await;
        assert_eq!(first, Some(Response::Error(ProtocolError::MessageTooLargeSent(102))));
    });
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment