Created
May 22, 2022 10:12
-
-
Save rkuhn/413aa0cb4f7415bbb10c3cddd1fa0615 to your computer and use it in GitHub Desktop.
A streaming-response protocol for rust-libp2p
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use super::{ | |
handler::{self, IntoHandler, Request, Response}, | |
RequestReceived, StreamingResponseConfig, | |
}; | |
use crate::Codec; | |
use futures::channel::mpsc; | |
use libp2p::{ | |
core::connection::ConnectionId, | |
swarm::{NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters}, | |
PeerId, | |
}; | |
use std::{ | |
collections::VecDeque, | |
marker::PhantomData, | |
task::{Context, Poll}, | |
}; | |
pub struct StreamingResponse<T: Codec + Send + 'static> { | |
config: StreamingResponseConfig, | |
events: VecDeque<RequestReceived<T>>, | |
requests: VecDeque<NetworkBehaviourAction<RequestReceived<T>, IntoHandler<T>>>, | |
_ph: PhantomData<T>, | |
} | |
impl<T: Codec + Send + 'static> StreamingResponse<T> { | |
pub fn new(config: StreamingResponseConfig) -> Self { | |
Self { | |
config, | |
events: VecDeque::default(), | |
requests: VecDeque::default(), | |
_ph: PhantomData, | |
} | |
} | |
pub fn request(&mut self, peer_id: PeerId, request: T::Request, channel: mpsc::Sender<Response<T::Response>>) { | |
self.requests.push_back(NetworkBehaviourAction::NotifyHandler { | |
peer_id, | |
handler: NotifyHandler::Any, | |
event: Request::new(request, channel), | |
}) | |
} | |
} | |
impl<T: Codec + Send + 'static> NetworkBehaviour for StreamingResponse<T> { | |
type ProtocolsHandler = IntoHandler<T>; | |
type OutEvent = RequestReceived<T>; | |
fn new_handler(&mut self) -> Self::ProtocolsHandler { | |
IntoHandler::new( | |
self.config.spawner.clone(), | |
self.config.max_message_size, | |
self.config.request_timeout, | |
self.config.response_send_buffer_size, | |
self.config.keep_alive, | |
) | |
} | |
fn inject_event( | |
&mut self, | |
peer_id: PeerId, | |
connection: ConnectionId, | |
event: <<Self::ProtocolsHandler as libp2p::swarm::IntoProtocolsHandler>::Handler as libp2p::swarm::ProtocolsHandler>::OutEvent, | |
) { | |
let handler::RequestReceived { request, channel } = event; | |
log::trace!("request received by behaviour: {:?}", request); | |
self.events.push_back(RequestReceived { | |
peer_id, | |
connection, | |
request, | |
channel, | |
}); | |
} | |
fn poll( | |
&mut self, | |
_cx: &mut Context<'_>, | |
_params: &mut impl PollParameters, | |
) -> Poll<NetworkBehaviourAction<Self::OutEvent, Self::ProtocolsHandler>> { | |
if let Some(action) = self.requests.pop_front() { | |
log::trace!("triggering request action"); | |
return Poll::Ready(action); | |
} | |
match self.events.pop_front() { | |
Some(e) => Poll::Ready(NetworkBehaviourAction::GenerateEvent(e)), | |
None => Poll::Pending, | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use super::{ | |
protocol::{self, Requester, Responder}, | |
ProtocolError, | |
}; | |
use crate::Codec; | |
use futures::{ | |
channel::mpsc, future::BoxFuture, stream::FuturesUnordered, AsyncWriteExt, FutureExt, SinkExt, StreamExt, | |
}; | |
use libp2p::{ | |
core::{ConnectedPoint, UpgradeError}, | |
swarm::{ | |
protocols_handler::{InboundUpgradeSend, OutboundUpgradeSend}, | |
IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, | |
SubstreamProtocol, | |
}, | |
PeerId, | |
}; | |
use std::{ | |
any::Any, | |
collections::VecDeque, | |
fmt::Debug, | |
io::ErrorKind, | |
marker::PhantomData, | |
sync::Arc, | |
task::{Context, Poll}, | |
time::Duration, | |
}; | |
#[derive(Debug, PartialEq)] | |
pub enum Response<T> { | |
Msg(T), | |
Error(ProtocolError), | |
Finished, | |
} | |
impl<T> Response<T> { | |
pub fn into_msg(self) -> Result<T, ProtocolError> { | |
match self { | |
Response::Msg(msg) => Ok(msg), | |
Response::Error(e) => Err(e), | |
Response::Finished => Err(ProtocolError::Io(ErrorKind::UnexpectedEof.into())), | |
} | |
} | |
} | |
pub struct Request<T: Codec> { | |
request: T::Request, | |
channel: mpsc::Sender<Response<T::Response>>, | |
} | |
impl<T: Codec> Request<T> { | |
pub fn new(request: T::Request, channel: mpsc::Sender<Response<T::Response>>) -> Self { | |
Self { request, channel } | |
} | |
} | |
impl<T: Codec> Debug for Request<T> { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
f.debug_struct("Request").field("request", &self.request).finish() | |
} | |
} | |
pub struct RequestReceived<T: Codec> { | |
pub(crate) request: T::Request, | |
pub(crate) channel: mpsc::Sender<T::Response>, | |
} | |
impl<T: Codec> Debug for RequestReceived<T> { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
f.debug_struct("RequestReceived") | |
.field("request", &self.request) | |
.finish() | |
} | |
} | |
pub struct IntoHandler<T> { | |
spawner: Spawner, | |
max_message_size: u32, | |
request_timeout: Duration, | |
response_send_buffer_size: usize, | |
keep_alive: bool, | |
_ph: PhantomData<T>, | |
} | |
impl<T> IntoHandler<T> { | |
pub fn new( | |
spawner: Spawner, | |
max_message_size: u32, | |
request_timeout: Duration, | |
response_send_buffer_size: usize, | |
keep_alive: bool, | |
) -> Self { | |
Self { | |
spawner, | |
max_message_size, | |
request_timeout, | |
response_send_buffer_size, | |
keep_alive, | |
_ph: PhantomData, | |
} | |
} | |
} | |
impl<T: Codec + Send + 'static> IntoProtocolsHandler for IntoHandler<T> { | |
type Handler = Handler<T>; | |
fn into_handler(self, _remote_peer_id: &PeerId, _connected_point: &ConnectedPoint) -> Self::Handler { | |
Handler::new( | |
self.spawner, | |
self.max_message_size, | |
self.request_timeout, | |
self.response_send_buffer_size, | |
self.keep_alive, | |
) | |
} | |
fn inbound_protocol(&self) -> <Self::Handler as ProtocolsHandler>::InboundProtocol { | |
Responder::new(self.max_message_size) | |
} | |
} | |
type ProtocolEvent<T> = ProtocolsHandlerEvent< | |
Requester<T>, | |
mpsc::Sender<Response<<T as Codec>::Response>>, | |
RequestReceived<T>, | |
ProtocolError, | |
>; | |
pub type ResponseFuture = BoxFuture<'static, Box<dyn Any + Send + 'static>>; | |
pub type Spawner = Arc<dyn Fn(ResponseFuture) -> ResponseFuture + Send + Sync + 'static>; | |
pub struct Handler<T: Codec> { | |
events: VecDeque<ProtocolEvent<T>>, | |
streams: FuturesUnordered<ResponseFuture>, | |
spawner: Spawner, | |
max_message_size: u32, | |
request_timeout: Duration, | |
response_send_buffer_size: usize, | |
keep_alive: bool, | |
} | |
impl<T: Codec> Debug for Handler<T> { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
f.debug_struct("Handler") | |
.field("events", &self.events.len()) | |
.field("streams", &self.streams.len()) | |
.finish() | |
} | |
} | |
impl<T: Codec> Handler<T> { | |
pub fn new( | |
spawner: Spawner, | |
max_message_size: u32, | |
request_timeout: Duration, | |
response_send_buffer_size: usize, | |
keep_alive: bool, | |
) -> Self { | |
Self { | |
events: VecDeque::default(), | |
streams: FuturesUnordered::default(), | |
spawner, | |
max_message_size, | |
request_timeout, | |
response_send_buffer_size, | |
keep_alive, | |
} | |
} | |
} | |
impl<T: Codec + Send + 'static> ProtocolsHandler for Handler<T> { | |
type InEvent = Request<T>; | |
type OutEvent = RequestReceived<T>; | |
type Error = ProtocolError; | |
type InboundProtocol = Responder<T>; | |
type OutboundProtocol = Requester<T>; | |
type InboundOpenInfo = (); | |
type OutboundOpenInfo = mpsc::Sender<Response<T::Response>>; | |
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> { | |
SubstreamProtocol::new(Responder::new(self.max_message_size), ()).with_timeout(self.request_timeout) | |
} | |
fn inject_fully_negotiated_inbound( | |
&mut self, | |
protocol: <Self::InboundProtocol as InboundUpgradeSend>::Output, | |
_info: Self::InboundOpenInfo, | |
) { | |
let (request, mut stream) = protocol; | |
let (channel, mut rx) = mpsc::channel(self.response_send_buffer_size); | |
let max_message_size = self.max_message_size; | |
log::trace!("handler received request"); | |
let task = (self.spawner)( | |
async move { | |
log::trace!("starting send loop"); | |
let mut buffer = Vec::new(); | |
loop { | |
// only flush once we’re going to sleep | |
let response = match rx.try_next() { | |
Ok(Some(r)) => r, | |
Ok(None) => break, | |
Err(_) => { | |
log::trace!("flushing stream"); | |
stream.flush().await?; | |
match rx.next().await { | |
Some(r) => r, | |
None => break, | |
} | |
} | |
}; | |
protocol::write_msg(&mut stream, response, max_message_size, &mut buffer).await?; | |
} | |
log::trace!("flushing and closing substream"); | |
protocol::write_finish(&mut stream).await?; | |
Result::<_, ProtocolError>::Ok(()) | |
} | |
.map(|res| -> Box<dyn Any + Send + 'static> { Box::new(res) }) | |
.boxed(), | |
); | |
self.streams.push(task); | |
self.events | |
.push_back(ProtocolsHandlerEvent::Custom(RequestReceived { request, channel })); | |
} | |
fn inject_fully_negotiated_outbound( | |
&mut self, | |
mut stream: <Self::OutboundProtocol as OutboundUpgradeSend>::Output, | |
mut tx: Self::OutboundOpenInfo, | |
) { | |
let max_message_size = self.max_message_size; | |
let task = (self.spawner)( | |
async move { | |
log::trace!("starting receive loop"); | |
let mut buffer = Vec::new(); | |
loop { | |
match protocol::read_msg(&mut stream, max_message_size, &mut buffer) | |
.await | |
.unwrap_or_else(Response::Error) | |
{ | |
Response::Msg(msg) => { | |
tx.feed(Response::Msg(msg)).await?; | |
log::trace!("response sent to client code"); | |
} | |
Response::Error(e) => { | |
log::debug!("sending substream error {}", e); | |
tx.feed(Response::Error(e)).await?; | |
return Result::<_, ProtocolError>::Ok(()); | |
} | |
Response::Finished => { | |
log::trace!("finishing substream"); | |
tx.feed(Response::Finished).await?; | |
return Ok(()); | |
} | |
} | |
} | |
} | |
.map(|res| -> Box<dyn Any + Send + 'static> { Box::new(res) }) | |
.boxed(), | |
); | |
self.streams.push(task); | |
} | |
fn inject_event(&mut self, command: Self::InEvent) { | |
let Request { request, channel } = command; | |
log::trace!("requesting {:?}", request); | |
self.events.push_back(ProtocolsHandlerEvent::OutboundSubstreamRequest { | |
protocol: SubstreamProtocol::new(Requester::new(self.max_message_size, request), channel) | |
.with_timeout(self.request_timeout), | |
}) | |
} | |
fn inject_dial_upgrade_error( | |
&mut self, | |
mut tx: Self::OutboundOpenInfo, | |
error: ProtocolsHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>, | |
) { | |
let error = match error { | |
ProtocolsHandlerUpgrErr::Timeout => ProtocolError::Timeout, | |
ProtocolsHandlerUpgrErr::Timer => ProtocolError::Timeout, | |
ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)) => e, | |
ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)) => e.into(), | |
}; | |
log::debug!("dial upgrade error: {}", error); | |
if let Err(Response::Error(e)) = tx.try_send(Response::Error(error)).map_err(|e| e.into_inner()) { | |
log::warn!("cannot send upgrade error to requester: {}", e); | |
} | |
} | |
fn connection_keep_alive(&self) -> KeepAlive { | |
if !self.keep_alive && self.streams.is_empty() { | |
KeepAlive::No | |
} else { | |
KeepAlive::Yes | |
} | |
} | |
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ProtocolEvent<T>> { | |
loop { | |
if self.streams.is_empty() { | |
break; | |
} | |
if let Poll::Ready(result) = self.streams.poll_next_unpin(cx) { | |
// since the set was not empty, this must be a Some() | |
if let Some(Err(e)) = result.and_then(|e| e.downcast::<Result<(), ProtocolError>>().ok().map(|b| *b)) { | |
// no need to tear down the connection, substream is already closed | |
log::warn!("error in substream task: {}", e); | |
} | |
} else { | |
break; | |
} | |
} | |
match self.events.pop_front() { | |
Some(e) => Poll::Ready(e), | |
None => Poll::Pending, | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
mod behaviour; | |
mod handler; | |
mod protocol; | |
#[cfg(test)] | |
mod tests; | |
pub use behaviour::StreamingResponse; | |
pub use handler::{Response, ResponseFuture, Spawner}; | |
pub use protocol::ProtocolError; | |
use crate::Codec; | |
use futures::channel::mpsc; | |
use libp2p::{core::connection::ConnectionId, PeerId}; | |
use std::{fmt::Debug, sync::Arc, time::Duration}; | |
pub struct RequestReceived<T: Codec> { | |
pub peer_id: PeerId, | |
pub connection: ConnectionId, | |
pub request: T::Request, | |
pub channel: mpsc::Sender<T::Response>, | |
} | |
impl<T: Codec> Debug for RequestReceived<T> { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
f.debug_struct("RequestReceived") | |
.field("peer_id", &self.peer_id) | |
.field("connection", &self.connection) | |
.field("request", &self.request) | |
.finish() | |
} | |
} | |
pub struct StreamingResponseConfig { | |
pub spawner: Spawner, | |
pub request_timeout: Duration, | |
pub max_message_size: u32, | |
pub response_send_buffer_size: usize, | |
pub keep_alive: bool, | |
} | |
impl StreamingResponseConfig { | |
/// Spawn response stream handling tasks using the given function | |
/// | |
/// This function may be called from an arbitrary context, you cannot assume that because | |
/// you’re using Tokio this will happen on a Tokio thread. Hence it is necessary to point | |
/// to the target thread pool directly, e.g. by using a runtime handle. | |
/// | |
/// If this method is not used, tasks will be polled via the Swarm, which may be an I/O | |
/// bottleneck. | |
pub fn with_spawner(self, spawner: impl Fn(ResponseFuture) -> ResponseFuture + Send + Sync + 'static) -> Self { | |
Self { | |
spawner: Arc::new(spawner), | |
..self | |
} | |
} | |
/// Timeout for the transmission of the request to the peer, default is 10sec | |
pub fn with_request_timeout(self, request_timeout: Duration) -> Self { | |
Self { | |
request_timeout, | |
..self | |
} | |
} | |
/// Maximum message size permitted for requests and responses | |
/// | |
/// The maximum is 4GiB, the default 1MB. Sending huge messages requires corresponding | |
/// buffers and may not be desirable. | |
pub fn with_max_message_size(self, max_message_size: u32) -> Self { | |
Self { | |
max_message_size, | |
..self | |
} | |
} | |
/// Set the queue size in messages for the channel created for incoming requests | |
/// | |
/// All channels are bounded in size and use back-pressure. This channel size allows some | |
/// decoupling between response generation and network transmission. Default is 128. | |
pub fn with_response_send_buffer_size(self, response_send_buffer_size: usize) -> Self { | |
Self { | |
response_send_buffer_size, | |
..self | |
} | |
} | |
/// If this is set to true, then this behaviour will keep the connection alive | |
/// | |
/// Otherwise the connection is released (i.e. closed if no other behaviour keeps it alive) | |
/// when there are no active requests ongoing. Default is `false`. | |
pub fn with_keep_alive(self, keep_alive: bool) -> Self { | |
Self { keep_alive, ..self } | |
} | |
} | |
impl Default for StreamingResponseConfig { | |
fn default() -> Self { | |
Self { | |
spawner: Arc::new(|f| f), | |
request_timeout: Duration::from_secs(10), | |
max_message_size: 1_000_000, | |
response_send_buffer_size: 128, | |
keep_alive: false, | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use super::handler::Response; | |
use crate::Codec; | |
use derive_more::{Display, Error, From}; | |
use futures::{channel::mpsc, future::BoxFuture, AsyncReadExt, AsyncWriteExt, FutureExt}; | |
use libp2p::{ | |
core::{upgrade::NegotiationError, UpgradeInfo}, | |
swarm::NegotiatedSubstream, | |
InboundUpgrade, OutboundUpgrade, | |
}; | |
use serde::de::DeserializeOwned; | |
use std::{ | |
fmt::{Display, Write}, | |
io::ErrorKind, | |
iter::{once, Once}, | |
marker::PhantomData, | |
}; | |
/// Errors of the streaming-response protocol.
///
/// They arise either locally or are reported by the peer through an error trailer, which
/// carries only the code from `as_code` and is decoded by `from_code`.
#[derive(Error, Display, Debug, From)]
pub enum ProtocolError {
    /// A substream upgrade did not complete in time (also used when the peer reports a
    /// timeout).
    #[display(fmt = "timeout while waiting for request receive confirmation")]
    Timeout,
    /// An incoming message announced a size above the limit.
    ///
    /// The payload is the announced size, or 0 when the error was reported by the peer.
    #[display(fmt = "message too large received: {}", _0)]
    #[from(ignore)]
    MessageTooLargeRecv(#[error(ignore)] usize),
    /// A message to be sent exceeded the limit.
    ///
    /// The payload is its encoded size, or 0 when the error was reported by the peer.
    #[display(fmt = "message too large sent: {}", _0)]
    #[from(ignore)]
    MessageTooLargeSent(#[error(ignore)] usize),
    /// Protocol negotiation on the substream failed.
    #[display(fmt = "substream protocol negotiation error: {}", _0)]
    Negotiation(NegotiationError),
    /// Reading from or writing to the substream failed, or the stream ended unexpectedly.
    #[display(fmt = "I/O error: {}", _0)]
    Io(std::io::Error),
    /// CBOR encoding or decoding of a message failed.
    #[display(fmt = "(de)serialisation error: {}", _0)]
    Serde(serde_cbor::Error),
    /// An internal channel could not accept a message (e.g. its receiver is gone).
    #[display(fmt = "internal channel error")]
    Channel(mpsc::SendError),
    /// This variant is useful for implementing the function to pass to
    /// [`with_spawner`](crate::v2::StreamingResponseConfig::with_spawner); the flag records
    /// whether the spawned task was cancelled
    #[display(fmt = "spawned task failed (cancelled={})", _0)]
    JoinError(#[error(ignore)] bool),
}
impl PartialEq for ProtocolError { | |
fn eq(&self, other: &Self) -> bool { | |
match (self, other) { | |
(Self::MessageTooLargeRecv(l0), Self::MessageTooLargeRecv(r0)) => l0 == r0, | |
(Self::MessageTooLargeSent(l0), Self::MessageTooLargeSent(r0)) => l0 == r0, | |
(Self::Negotiation(l0), Self::Negotiation(r0)) => l0.to_string() == r0.to_string(), | |
(Self::Io(l0), Self::Io(r0)) => l0.to_string() == r0.to_string(), | |
(Self::Serde(l0), Self::Serde(r0)) => l0.to_string() == r0.to_string(), | |
(Self::Channel(l0), Self::Channel(r0)) => l0 == r0, | |
(Self::JoinError(l0), Self::JoinError(r0)) => l0 == r0, | |
_ => core::mem::discriminant(self) == core::mem::discriminant(other), | |
} | |
} | |
} | |
impl ProtocolError { | |
pub fn as_code(&self) -> u8 { | |
match self { | |
ProtocolError::Timeout => 1, | |
ProtocolError::MessageTooLargeRecv(_) => 2, | |
ProtocolError::MessageTooLargeSent(_) => 3, | |
ProtocolError::Negotiation(_) => 4, | |
ProtocolError::Io(_) => 5, | |
ProtocolError::Serde(_) => 6, | |
ProtocolError::Channel(_) => 7, | |
ProtocolError::JoinError(_) => 8, | |
} | |
} | |
pub fn from_code(code: u8) -> Self { | |
match code { | |
1 => ProtocolError::Timeout, | |
2 => ProtocolError::MessageTooLargeRecv(0), | |
3 => ProtocolError::MessageTooLargeSent(0), | |
4 => ProtocolError::Negotiation(NegotiationError::Failed), | |
5 => ProtocolError::Io(std::io::Error::new(ErrorKind::Other, "some error on peer")), | |
6 => ProtocolError::Serde(std::io::Error::new(ErrorKind::Other, "serde error on peer").into()), | |
7 => { | |
let (mut tx, _) = mpsc::channel(1); | |
let err = tx.try_send(0).unwrap_err().into_send_error(); | |
ProtocolError::Channel(err) | |
} | |
8 => ProtocolError::JoinError(false), | |
n => ProtocolError::Io(std::io::Error::new( | |
ErrorKind::Other, | |
format!("unknown error code {}", n), | |
)), | |
} | |
} | |
} | |
pub async fn write_msg( | |
io: &mut NegotiatedSubstream, | |
msg: impl serde::Serialize, | |
max_size: u32, | |
buffer: &mut Vec<u8>, | |
) -> Result<(), ProtocolError> { | |
buffer.resize(4, 0); | |
let res = serde_cbor::to_writer(&mut *buffer, &msg); | |
if let Err(e) = res { | |
let err = ProtocolError::Serde(e); | |
write_err(io, &err).await?; | |
return Err(err); | |
} | |
let size = buffer.len() - 4; | |
if size > (max_size as usize) { | |
log::debug!("message size {} too large (max = {})", size, max_size); | |
let err = ProtocolError::MessageTooLargeSent(size); | |
write_err(io, &err).await?; | |
return Err(err); | |
} | |
log::trace!("sending message of size {}", size); | |
buffer.as_mut_slice()[..4].copy_from_slice(&(size as u32).to_be_bytes()); | |
io.write_all(buffer.as_slice()).await?; | |
Ok(()) | |
} | |
pub async fn write_err(io: &mut NegotiatedSubstream, err: &ProtocolError) -> Result<(), std::io::Error> { | |
let buf = [255, err.as_code()]; | |
io.write_all(&buf).await?; | |
io.flush().await?; | |
io.close().await?; | |
Ok(()) | |
} | |
pub async fn write_finish(io: &mut NegotiatedSubstream) -> Result<(), std::io::Error> { | |
let buf = [255, 0]; | |
io.write_all(&buf).await?; | |
io.flush().await?; | |
io.close().await?; | |
Ok(()) | |
} | |
pub async fn read_msg<T: DeserializeOwned>( | |
io: &mut NegotiatedSubstream, | |
max_size: u32, | |
buffer: &mut Vec<u8>, | |
) -> Result<Response<T>, ProtocolError> { | |
let mut size_bytes = [0u8; 4]; | |
let mut to_read = &mut size_bytes[..]; | |
while !to_read.is_empty() { | |
let read = io.read(to_read).await?; | |
log::trace!("read {} header bytes", read); | |
if read == 0 { | |
let len = to_read.len(); | |
let read = &size_bytes[..4 - len]; | |
if read.len() != 2 || read[0] != 255 { | |
return Err(ProtocolError::Io(ErrorKind::UnexpectedEof.into())); | |
} else { | |
return match read[1] { | |
0 => Ok(Response::Finished), | |
n => Err(ProtocolError::from_code(n)), | |
}; | |
} | |
} | |
to_read = to_read.split_at_mut(read).1; | |
} | |
let size = u32::from_be_bytes(size_bytes); | |
if size > max_size { | |
log::debug!("message size {} too large (max = {})", size, max_size); | |
let mut bytes = [0u8; 4096]; | |
bytes[..4].copy_from_slice(&size_bytes); | |
let n = io.read(&mut bytes[4..]).await?; | |
log::debug!("{:?}", &bytes[..n + 4]); | |
return Err(ProtocolError::MessageTooLargeRecv(size as usize)); | |
} | |
log::trace!("received header: msg is {} bytes", size); | |
buffer.resize(size as usize, 0); | |
io.read_exact(buffer.as_mut_slice()).await?; | |
log::trace!("all bytes read"); | |
Ok(Response::Msg(serde_cbor::from_slice(buffer.as_slice())?)) | |
} | |
#[derive(Debug)] | |
pub struct Responder<T> { | |
max_message_size: u32, | |
_ph: PhantomData<T>, | |
} | |
impl<T> Responder<T> { | |
pub fn new(max_message_size: u32) -> Self { | |
Self { | |
max_message_size, | |
_ph: PhantomData, | |
} | |
} | |
} | |
impl<T: Codec> UpgradeInfo for Responder<T> { | |
type Info = &'static [u8]; | |
type InfoIter = Once<&'static [u8]>; | |
fn protocol_info(&self) -> Self::InfoIter { | |
once(T::protocol_info()) | |
} | |
} | |
/// `Display` adapter for a protocol name given as raw bytes.
///
/// Printable ASCII (space through `~`) is shown as-is. Every other byte, including control
/// characters such as DEL (0x7f) and all non-ASCII bytes, is replaced by U+FFFD so that log
/// lines stay readable.
struct ProtoNameDisplay(&'static [u8]);

impl Display for ProtoNameDisplay {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for &byte in self.0 {
            // the upper bound is exclusive: 0x7f (DEL) is a control character, not printable
            let shown = if (0x20u8..0x7f).contains(&byte) {
                char::from(byte)
            } else {
                '\u{fffd}'
            };
            f.write_char(shown)?;
        }
        Ok(())
    }
}
impl<T: Codec> InboundUpgrade<NegotiatedSubstream> for Responder<T> { | |
type Output = (T::Request, NegotiatedSubstream); | |
type Error = ProtocolError; | |
type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>; | |
fn upgrade_inbound(self, mut socket: NegotiatedSubstream, info: Self::Info) -> Self::Future { | |
let max_message_size = self.max_message_size; | |
async move { | |
log::trace!("starting inbound upgrade `{}`", ProtoNameDisplay(info)); | |
let msg = read_msg(&mut socket, max_message_size, &mut Vec::new()) | |
.await? | |
.into_msg()?; | |
log::trace!("request received: {:?}", msg); | |
Ok((msg, socket)) | |
} | |
.boxed() | |
} | |
} | |
#[derive(Debug)] | |
pub struct Requester<T: Codec> { | |
max_message_size: u32, | |
request: T::Request, | |
} | |
impl<T: Codec> Requester<T> { | |
pub fn new(max_message_size: u32, request: T::Request) -> Self { | |
Self { | |
max_message_size, | |
request, | |
} | |
} | |
} | |
impl<T: Codec> UpgradeInfo for Requester<T> { | |
type Info = &'static [u8]; | |
type InfoIter = Once<&'static [u8]>; | |
fn protocol_info(&self) -> Self::InfoIter { | |
once(T::protocol_info()) | |
} | |
} | |
impl<T: Codec> OutboundUpgrade<NegotiatedSubstream> for Requester<T> { | |
type Output = NegotiatedSubstream; | |
type Error = ProtocolError; | |
type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>; | |
fn upgrade_outbound(self, mut socket: NegotiatedSubstream, info: Self::Info) -> Self::Future { | |
let Self { | |
max_message_size, | |
request, | |
} = self; | |
async move { | |
log::trace!("starting output upgrade `{}`", ProtoNameDisplay(info)); | |
write_msg(&mut socket, request, max_message_size, &mut Vec::new()).await?; | |
socket.flush().await?; | |
log::trace!("all bytes sent"); | |
Ok(socket) | |
} | |
.boxed() | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use super::{ProtocolError, StreamingResponse, StreamingResponseConfig}; | |
use crate::{ | |
v2::{handler::Response, RequestReceived}, | |
Codec, | |
}; | |
use futures::{ | |
channel::mpsc::{self, Receiver, Sender}, | |
Future, FutureExt, SinkExt, StreamExt, | |
}; | |
use libp2p::{ | |
core::{transport::MemoryTransport, upgrade::Version}, | |
identity::Keypair, | |
mplex::MplexConfig, | |
multiaddr::Protocol, | |
plaintext::PlainText2Config, | |
swarm::{SwarmBuilder, SwarmEvent}, | |
Multiaddr, PeerId, Swarm, Transport, | |
}; | |
use tokio::runtime::{Handle, Runtime}; | |
use tracing_subscriber::{fmt::format::FmtSpan, util::SubscriberInitExt, EnvFilter}; | |
mod proto; | |
const PROTO: &[u8] = b"/my/test"; | |
fn test_swarm(use_spawner: Option<Handle>) -> Swarm<StreamingResponse<Proto>> { | |
let local_key = Keypair::generate_ed25519(); | |
let local_public_key = local_key.public(); | |
let local_peer_id = local_public_key.clone().into(); | |
let transport = MemoryTransport::default() | |
.upgrade(Version::V1) | |
.authenticate(PlainText2Config { local_public_key }) | |
.multiplex(MplexConfig::new()) | |
.boxed(); | |
let mut config = StreamingResponseConfig::default() | |
.with_keep_alive(true) | |
.with_max_message_size(100); | |
#[allow(clippy::redundant_closure)] | |
if let Some(rt) = use_spawner { | |
config = config.with_spawner(move |f| rt.spawn(f).map(|r| r.unwrap_or_else(|e| Box::new(e))).boxed()); | |
} | |
let behaviour = StreamingResponse::new(config); | |
SwarmBuilder::new(transport, behaviour, local_peer_id).build() | |
} | |
fn fake_swarm(rt: &Runtime, bytes: &[u8]) -> Swarm<proto::TestBehaviour> { | |
let local_key = Keypair::generate_ed25519(); | |
let local_public_key = local_key.public(); | |
let local_peer_id = local_public_key.clone().into(); | |
let transport = MemoryTransport::default() | |
.upgrade(Version::V1) | |
.authenticate(PlainText2Config { local_public_key }) | |
.multiplex(MplexConfig::new()) | |
.boxed(); | |
let behaviour = proto::TestBehaviour(rt.handle().clone(), bytes.to_owned()); | |
SwarmBuilder::new(transport, behaviour, local_peer_id).build() | |
} | |
struct Proto; | |
impl Codec for Proto { | |
type Request = String; | |
type Response = String; | |
fn protocol_info() -> &'static [u8] { | |
PROTO | |
} | |
} | |
/// Poll swarm `$s` until it yields an event matching `$p`, logging every event on the way.
///
/// The macro evaluates to `$e`, with the pattern's bindings in scope. It panics if the swarm
/// stream ends first. Must be used inside an async context.
macro_rules! wait4 {
    ($s:ident, $p:pat => $e:expr) => {
        loop {
            let ev = $s.next().await;
            if ev.is_none() {
                panic!("{} STOPPED", stringify!($s))
            }
            let ev = ev.unwrap();
            log::info!("{} got {:?}", stringify!($s), ev);
            // `break` with a value makes the whole `loop` evaluate to `$e`
            if let $p = ev {
                break $e;
            }
        }
    };
}
/// Spawn a Tokio task that takes ownership of swarm `$s` and drives it until it stops.
///
/// Every event is logged. `$e` is evaluated for events matching one of the `$p` patterns; all
/// other events are ignored.
macro_rules! task {
    ($s:ident $(, $p:pat => $e:expr)*) => {
        tokio::spawn(async move {
            while let Some(ev) = $s.next().await {
                log::info!("{} got {:?}", stringify!($s), ev);
                match ev {
                    $($p => ($e),)*
                    _ => {}
                }
            }
            log::info!("{} STOPPED", stringify!($s));
        })
    };
}
/// Render a value's `Debug` output as a `String`.
///
/// Used to compare values whose types are hard to construct or lack `PartialEq`.
fn dbg<T: std::fmt::Debug>(value: T) -> String {
    format!("{:?}", value)
}
fn setup_logger() { | |
tracing_subscriber::fmt() | |
.with_env_filter(EnvFilter::from_default_env()) | |
.with_span_events(FmtSpan::ENTER | FmtSpan::CLOSE) | |
.finish() | |
.try_init() | |
.ok(); | |
} | |
#[test] | |
fn smoke() { | |
setup_logger(); | |
let rt = Runtime::new().unwrap(); | |
let mut asker = test_swarm(None); | |
let asker_id = *asker.local_peer_id(); | |
let mut responder = test_swarm(None); | |
let responder_id = *responder.local_peer_id(); | |
asker.listen_on(Multiaddr::empty().with(Protocol::Memory(0))).unwrap(); | |
rt.block_on(async move { | |
let addr = wait4!(asker, SwarmEvent::NewListenAddr { address, .. } => address); | |
responder.dial(addr).unwrap(); | |
task!(responder, | |
SwarmEvent::Behaviour(RequestReceived { request, peer_id, mut channel, .. }) => { | |
tokio::spawn(async move { | |
channel.feed(request).await.unwrap(); | |
channel.feed(peer_id.to_string()).await.unwrap(); | |
channel.close().await.unwrap(); | |
}); | |
} | |
); | |
let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id); | |
assert_eq!(peer_id, responder_id); | |
let (tx, rx) = mpsc::channel(10); | |
asker.behaviour_mut().request(peer_id, "request".to_owned(), tx); | |
task!(asker); | |
let response = rx | |
.map(|r| match r { | |
Response::Msg(m) => Some(m), | |
Response::Error(e) => panic!("got error: {:#}", e), | |
Response::Finished => None, | |
}) | |
.collect::<Vec<_>>() | |
.await; | |
assert_eq!( | |
response, | |
vec![Some("request".to_owned()), Some(asker_id.to_string()), None] | |
); | |
}); | |
} | |
#[test] | |
fn smoke_executor() { | |
setup_logger(); | |
let rt = Runtime::new().unwrap(); | |
let mut asker = test_swarm(Some(rt.handle().clone())); | |
let asker_id = *asker.local_peer_id(); | |
let mut responder = test_swarm(Some(rt.handle().clone())); | |
let responder_id = *responder.local_peer_id(); | |
asker.listen_on(Multiaddr::empty().with(Protocol::Memory(0))).unwrap(); | |
rt.block_on(async move { | |
let addr = wait4!(asker, SwarmEvent::NewListenAddr { address, .. } => address); | |
responder.dial(addr).unwrap(); | |
task!(responder, | |
SwarmEvent::Behaviour(RequestReceived { request, peer_id, mut channel, .. }) => { | |
tokio::spawn(async move { | |
channel.feed(request).await.unwrap(); | |
channel.feed(peer_id.to_string()).await.unwrap(); | |
channel.close().await.unwrap(); | |
}); | |
} | |
); | |
let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id); | |
assert_eq!(peer_id, responder_id); | |
let (tx, rx) = mpsc::channel(10); | |
asker.behaviour_mut().request(peer_id, "request".to_owned(), tx); | |
task!(asker); | |
let response = rx | |
.map(|r| match r { | |
Response::Msg(m) => Some(m), | |
Response::Error(e) => panic!("got error: {:#}", e), | |
Response::Finished => None, | |
}) | |
.collect::<Vec<_>>() | |
.await; | |
assert_eq!( | |
response, | |
vec![Some("request".to_owned()), Some(asker_id.to_string()), None] | |
); | |
}); | |
} | |
fn test_setup<F, Fut, L>(request: String, logic: L, f: F) | |
where | |
F: FnOnce(Receiver<Response<String>>) -> Fut + Send + 'static, | |
Fut: Future, | |
L: Fn(String, PeerId, Sender<String>) + Send + 'static, | |
{ | |
setup_logger(); | |
let rt = Runtime::new().unwrap(); | |
let mut asker = test_swarm(None); | |
let mut responder = test_swarm(None); | |
rt.block_on(async move { | |
responder | |
.listen_on(Multiaddr::empty().with(Protocol::Memory(0))) | |
.unwrap(); | |
let addr = wait4!(responder, SwarmEvent::NewListenAddr{ address, .. } => address); | |
task!(responder, SwarmEvent::Behaviour(RequestReceived { request, peer_id, channel, .. }) => logic(request, peer_id, channel)); | |
asker.dial(addr).unwrap(); | |
let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id); | |
let (tx, rx) = mpsc::channel(10); | |
asker.behaviour_mut().request(peer_id, request, tx); | |
task!(asker); | |
f(rx).await; | |
}); | |
} | |
fn fake_setup<F, Fut>(bytes: &[u8], f: F) | |
where | |
F: FnOnce(Receiver<Response<String>>) -> Fut + Send + 'static, | |
Fut: Future, | |
{ | |
setup_logger(); | |
let rt = Runtime::new().unwrap(); | |
let mut asker = test_swarm(None); | |
let mut responder = fake_swarm(&rt, bytes); | |
rt.block_on(async move { | |
responder | |
.listen_on(Multiaddr::empty().with(Protocol::Memory(0))) | |
.unwrap(); | |
let addr = wait4!(responder, SwarmEvent::NewListenAddr{ address, .. } => address); | |
task!(responder); | |
asker.dial(addr).unwrap(); | |
let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id); | |
let (tx, rx) = mpsc::channel(10); | |
asker.behaviour_mut().request(peer_id, "request".to_owned(), tx); | |
task!(asker); | |
f(rx).await; | |
}); | |
} | |
/// The header "zzzz" announces 0x7a7a7a7a bytes, far above the 100-byte test limit.
#[test]
fn err_size() {
    fake_setup(b"zzzz", |mut rx| async move {
        let expected = Response::Error(ProtocolError::MessageTooLargeRecv(2054847098));
        assert_eq!(rx.next().await, Some(expected));
    });
}

/// A response stream that ends without a single byte is an unexpected EOF.
#[test]
fn err_nothing() {
    fake_setup(b"", |mut rx| async move {
        let first = rx.next().await.unwrap();
        assert_eq!(dbg(first), "Error(Io(Kind(UnexpectedEof)))");
    });
}

/// A valid frame ("abcd" as CBOR text) followed by a header promising 16 bytes, of which only
/// 4 arrive, yields the message and then an unexpected EOF.
#[test]
fn err_incomplete() {
    fake_setup(b"\0\0\0\x05dabcd\0\0\0\x10abcd", |mut rx| async move {
        let msg = rx.next().await;
        assert_eq!(msg, Some(Response::Msg("abcd".to_owned())));
        let err = rx.next().await.unwrap();
        assert_eq!(dbg(err), "Error(Io(Kind(UnexpectedEof)))");
    });
}

/// A valid frame without the end-of-stream trailer yields the message and then an unexpected
/// EOF.
#[test]
fn err_no_finish() {
    fake_setup(b"\0\0\0\x05dabcd", |mut rx| async move {
        let msg = rx.next().await;
        assert_eq!(msg, Some(Response::Msg("abcd".to_owned())));
        let err = rx.next().await.unwrap();
        assert_eq!(dbg(err), "Error(Io(Kind(UnexpectedEof)))");
    });
}

/// The 4-byte body "abcd" is not a single CBOR value, so decoding fails.
#[test]
fn err_deser() {
    fake_setup(b"\0\0\0\x04abcd", |mut rx| async move {
        let first = rx.next().await;
        assert_eq!(
            dbg(first),
            "Some(Error(Serde(ErrorImpl { code: TrailingData, offset: 3 })))"
        );
    });
}
#[test] | |
fn err_response_size() { | |
test_setup( | |
"123456789012345678901234567890123456789012345678901234567890".to_owned(), | |
|mut request, peer_id, mut channel| { | |
tokio::spawn(async move { | |
request.push_str(&*peer_id.to_string()); | |
channel.feed(request).await.unwrap(); | |
}); | |
}, | |
|mut rx| async move { | |
assert_eq!( | |
rx.next().await, | |
Some(Response::Error(ProtocolError::MessageTooLargeSent(0))) | |
); | |
}, | |
); | |
} | |
#[test] | |
fn err_request_size() { | |
test_setup( | |
"1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890" | |
.to_owned(), | |
|mut request, peer_id, mut channel| { | |
tokio::spawn(async move { | |
request.push_str(&*peer_id.to_string()); | |
channel.feed(request).await.unwrap(); | |
}); | |
}, | |
|mut rx| async move { | |
assert_eq!( | |
rx.next().await, | |
Some(Response::Error(ProtocolError::MessageTooLargeSent(102))) | |
); | |
}, | |
); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment