Skip to content

Instantly share code, notes, and snippets.

@BlinkyStitt
Last active February 8, 2024 21:30
Show Gist options
  • Save BlinkyStitt/837ffdbee2892c8857d192a5f4604b9c to your computer and use it in GitHub Desktop.
Save BlinkyStitt/837ffdbee2892c8857d192a5f4604b9c to your computer and use it in GitHub Desktop.
I'll turn this into a proper PR, but I figured I would share what I have so far.
use async_trait::async_trait;
use ethers::{
contract::{
multicall_contract::{Call3, Multicall3},
ContractError, Multicall,
},
providers::{Middleware, MiddlewareError},
types::{transaction::eip2718::TypedTransaction, Address, BlockId, Bytes},
};
use futures::{FutureExt, TryFutureExt};
use std::{collections::HashMap, fmt::Debug, sync::Arc, time::Duration};
use thiserror::Error;
use tokio::{
pin, select,
sync::oneshot,
task::JoinHandle,
time::{sleep_until, Instant},
};
use tracing::{debug, error, trace};
/// Result handed back to a caller for one queued call: the returned bytes, or a [`ContractError`].
pub type MulticallResponse<M> = Result<Bytes, ContractError<M>>;
/// Middleware whose `call` hands requests to a background task instead of sending them itself.
#[derive(Debug)]
pub struct MulticallMiddleware<M: Middleware> {
// the wrapped middleware; `call` forwards to it directly when the target is the multicall contract
inner: Arc<M>,
// address of the multicall contract, compared against each call's `to` address
multicall_address: Address,
// sending half of the queue that the background task reads from
task_tx: flume::Sender<PendingAction<M>>,
/// Handle of the background task spawned by `new`.
pub task_handle: JoinHandle<anyhow::Result<()>>,
}
/// A request that can be queued for the background task.
#[derive(Debug)]
pub enum MulticallAction {
/// A call transaction plus the optional block it was requested at.
Call(TypedTransaction, Option<BlockId>),
// Balance(Address, Option<BlockId>),
// TODO: there are more things
}
/// State owned by the background task that collects queued actions and sends them in batches.
struct MulticallMiddlewareTask<M: Middleware> {
// multicall helper; its `contract` is cloned for every batch that is sent
multicall: Multicall<M>,
// a flush is triggered once the queue for one block id holds this many actions
batch_size: usize,
// how long a freshly created queue may wait before it has to be sent
max_wait: Duration,
// pending actions grouped by the block id they were requested at
queue: HashMap<Option<BlockId>, PendingActions<M>>,
// receiving half of the queue filled by `MulticallMiddleware::call`
rx: flume::Receiver<PendingAction<M>>,
}
#[derive(Error, Debug)]
pub enum MulticallMiddlewareError<M: Middleware> {
#[error("{0}")]
Middleware(M::Error),
#[error("{0}")]
Recv(#[from] oneshot::error::RecvError),
#[error("{0:?}")]
PendingActionSend(#[from] flume::SendError<PendingAction<M>>),
#[error("{0}")]
/// TODO: this doesn't feel right
Contract(#[from] ContractError<M>),
}
impl MulticallAction {
fn block_id(&self) -> &Option<BlockId> {
match self {
MulticallAction::Call(_, block) => block,
// MulticallAction::Balance(_, block) => block,
}
}
}
impl<M: Middleware + 'static> MulticallMiddleware<M> {
pub async fn new(
inner: Arc<M>,
capacity: Option<usize>,
address: Option<Address>,
batch_size: Option<usize>,
max_wait: Option<Duration>,
) -> Result<Self, MulticallMiddlewareError<M>> {
let (tx, rx) = if let Some(capacity) = capacity {
// TODO! be careful with capacity. deadlocks are possible because we use `send`. we could `try_send` and if that fails do a future?
// flume::bounded(capacity)
error!("bounded capacity is currently ignored. it doesn't play well with using the blocking send");
flume::unbounded()
} else {
flume::unbounded()
};
let batch_size = batch_size.unwrap_or(200);
let max_wait = max_wait.unwrap_or_else(|| Duration::from_millis(1));
// TODO: don't unwrap
let multicall = Multicall::<M>::new(inner.clone(), address).await.unwrap();
let address = multicall.contract.address();
let queue = HashMap::new();
let task = MulticallMiddlewareTask {
multicall,
batch_size,
max_wait,
queue,
rx,
};
let task_handle = tokio::spawn(task.run().inspect_err(|e| {
// TODO: i think this needs to be a panic, but i'm not positive
panic!("MulticallMiddlewareTask error: {:?}", e);
}));
let x = Self {
inner,
multicall_address: address,
task_tx: tx,
task_handle,
};
Ok(x)
}
}
/// Lets the middleware stack wrap and unwrap errors of the inner middleware.
impl<M: Middleware + 'static> MiddlewareError for MulticallMiddlewareError<M> {
    type Inner = M::Error;

    /// Wraps an inner middleware error in the `Middleware` variant.
    fn from_err(src: M::Error) -> MulticallMiddlewareError<M> {
        Self::Middleware(src)
    }

    /// Returns the inner middleware error when this is the `Middleware` variant, `None` otherwise.
    fn as_inner(&self) -> Option<&Self::Inner> {
        if let Self::Middleware(inner_err) = self {
            Some(inner_err)
        } else {
            None
        }
    }
}
#[async_trait]
impl<M> Middleware for MulticallMiddleware<M>
where
    M: Middleware + 'static,
{
    type Error = MulticallMiddlewareError<M>;
    type Provider = M::Provider;
    type Inner = M;

    fn inner(&self) -> &M {
        &self.inner
    }

    /// Queues `tx` for the background task and returns a future that resolves with its result.
    ///
    /// notice that this does not use the async keyword. the oneshot is set up and then a future is returned!
    ///
    /// Two kinds of calls are forwarded to the inner middleware directly:
    /// * calls that already target the multicall contract, and
    /// * calls without a plain `to` address (no recipient, or a name instead of an address).
    ///   These cannot be turned into a batched call because a batched call needs a target address.
    fn call<'life0, 'life1, 'async_trait>(
        &'life0 self,
        tx: &'life1 TypedTransaction,
        block: Option<BlockId>,
    ) -> ::core::pin::Pin<
        Box<
            dyn ::core::future::Future<Output = Result<Bytes, Self::Error>>
                + ::core::marker::Send
                + 'async_trait,
        >,
    >
    where
        'life0: 'async_trait,
        'life1: 'async_trait,
        Self: 'async_trait,
    {
        // only calls with a concrete target other than the multicall contract can be batched
        let batchable = tx
            .to()
            .and_then(|to| to.as_address())
            .map_or(false, |to| *to != self.multicall_address);

        if !batchable {
            let direct_f = self
                .inner
                .call(tx, block)
                .map_err(MulticallMiddlewareError::from_err);
            return direct_f.boxed();
        }

        // set up a channel for the result. the background task will do the actual querying
        let (result_tx, result_rx) = oneshot::channel();
        let pending_action = PendingAction {
            action: MulticallAction::Call(tx.clone(), block),
            result_tx,
        };

        // be very careful with bounded channels! they can block the tokio runtime!
        if let Err(err) = self.task_tx.send(pending_action) {
            let err_f = async move { Err(MulticallMiddlewareError::from(err)) };
            return err_f.boxed();
        }

        // the outer `Result` is the channel (closed without an answer), the inner one is the call itself
        result_rx
            .map(|received| match received {
                Ok(Ok(bytes)) => Ok(bytes),
                Ok(Err(contract_err)) => Err(MulticallMiddlewareError::from(contract_err)),
                Err(recv_err) => Err(MulticallMiddlewareError::from(recv_err)),
            })
            .boxed()
    }
}
/// One queued request together with the channel its result is sent back on.
#[derive(Debug)]
pub struct PendingAction<M: Middleware> {
// what was requested
action: MulticallAction,
// one-shot sender used to deliver the response for this action
result_tx: oneshot::Sender<MulticallResponse<M>>,
}
// TODO: this type is way too complex. think about how to re-arrange it
/// A group of queued actions and the deadline by which the group has to be sent.
pub struct PendingActions<M: Middleware> {
// deadline, set to creation time plus the timeout passed to `new`
must_send_by: Instant,
// the queued actions, in the order they were pushed
actions: Vec<PendingAction<M>>,
}
impl<M: Middleware> PendingActions<M> {
    /// Creates an empty group whose deadline is `timeout` from now.
    fn new(timeout: Duration) -> Self {
        let must_send_by = Instant::now() + timeout;
        let actions = Vec::new();
        Self {
            must_send_by,
            actions,
        }
    }
}
impl<M: Middleware + 'static> MulticallMiddlewareTask<M> {
/// TODO: this probably needs to be in an Arc
async fn run(mut self) -> anyhow::Result<()> {
let mut batches = 0;
let mut total_calls = 0;
let mut single_item_batches = 0;
loop {
let first = self.rx.recv_async().await?;
let first_id = *first.action.block_id();
{
let first_entry = self
.queue
.entry(first_id)
.or_insert_with(|| PendingActions::new(self.max_wait));
first_entry.actions.push(first);
}
// there might be items still in the queue. we want to use their must_send_by
let must_send_at = self
.queue
.values()
.min_by(|a, b| a.must_send_by.cmp(&b.must_send_by))
.unwrap()
.must_send_by;
let wait_until = sleep_until(must_send_at);
pin!(wait_until);
loop {
select! {
x = self.rx.recv_async() => {
let x = x?;
let block_id = *x.action.block_id();
let entry = self.queue.entry(block_id).or_insert_with(|| PendingActions::new(self.max_wait));
entry.actions.push(x);
let key_len = entry.actions.len();
// TODO: this size should be configurable
if key_len >= self.batch_size {
trace!("size met");
// TODO: breaking here means we drain ALL of them. but we actually only want to drain this key! maybe instead of break, we call flush_id(key)
break;
}
},
_ = &mut wait_until => {
trace!("multicall aged out");
break;
},
};
}
// TODO: don't drain everything. only drain queues that are full or have been waiting for a certain amount of time
for (block_id, pending_actions) in self.queue.drain() {
let multicall_contract = self.multicall.contract.clone();
let mut calls = vec![];
for pending_action in pending_actions.actions.iter() {
match &pending_action.action {
MulticallAction::Call(tx, _) => calls.push(Call3 {
target: *tx.to_addr().unwrap(),
call_data: tx.data().cloned().unwrap_or_else(Bytes::new),
allow_failure: true,
}),
}
}
let new_calls = calls.len();
total_calls += new_calls;
batches += 1;
if new_calls == 1 {
// TODO: call directly?
// log this because it might mean that we aren't looping in a way that allows batching
debug!("single call: {:?}", calls[0]);
single_item_batches += 1;
}
debug!(
"batching {} calls. ({} reduced to {}. {} singles)",
new_calls, total_calls, batches, single_item_batches,
);
let f = multicall_aggregate(multicall_contract, pending_actions, calls, block_id);
tokio::spawn(f);
}
}
}
}
/// spawn this to run the multicall in the background
async fn multicall_aggregate<M: Middleware + 'static>(
multicall_contract: Multicall3<M>,
pending_actions: PendingActions<M>,
calls: Vec<Call3>,
block_id: Option<BlockId>,
) {
let mut aggregate_call = multicall_contract.aggregate_3(calls);
if let Some(x) = block_id {
aggregate_call = aggregate_call.block(x);
}
let results = match aggregate_call.await {
Err(err) => {
// drop the pending_actions. they will resolve with RecvError and can be retried
// TODO: do something else?
error!("multicall aggregate3 failed: {:?}", err);
return;
}
Ok(x) => x,
};
for (pending_action, result) in pending_actions.actions.into_iter().zip(results) {
let result = if result.success {
Ok(result.return_data)
} else {
// TODO: can we use the inner's existing Error::Revert?
// TODO: should this return a ContractError instead?
Err(ContractError::Revert(result.return_data))
};
let _ = pending_action.result_tx.send(result);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment