// rpc.rs — secret GitHub gist by @ppamorim (gist id 8294fde7025cb66bb8639551c8abac2d),
// created February 3, 2022.
use std::sync::Arc;
use async_trait::async_trait;
use openraft::error::AppendEntriesError;
use openraft::error::InstallSnapshotError;
use openraft::error::NetworkError;
use openraft::error::RPCError;
use openraft::error::RemoteError;
use openraft::error::VoteError;
use openraft::raft::AppendEntriesRequest;
use openraft::raft::AppendEntriesResponse;
use openraft::raft::InstallSnapshotRequest;
use openraft::raft::InstallSnapshotResponse;
use openraft::raft::VoteRequest;
use openraft::raft::VoteResponse;
use openraft::NodeId;
use openraft::RaftNetwork;
use serde::Deserialize;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::error::Error;
use crate::store::ExampleRequest;
use crate::MemStore;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, thiserror::Error)]
pub enum StateMachineError {
#[error("Node not found in the state machine")]
NodeNotFound(NodeId),
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, thiserror::Error)]
pub enum RPCNetworkError<T: Error> {
RPCError(RPCError<T>),
StateMachineError(StateMachineError)
}
/// HTTP/JSON transport implementing `RaftNetwork` for `ExampleRequest`.
///
/// Target addresses are looked up in the `nodes` map of the shared store's state machine.
pub struct RPCNetwork {
    /// Shared store whose state machine maps node ids to the addresses used in request URLs.
    pub store: Arc<MemStore>,
}
impl RPCNetwork {
pub async fn send_rpc<Req, Resp, Err>(
&self,
target: NodeId,
uri: &str,
req: Req
) -> Result<Resp, RPCNetworkError<Err>>
where
Req: Serialize,
Err: std::error::Error + DeserializeOwned,
Resp: DeserializeOwned,
{
let state_machine = self.store.state_machine.read().await;
let addr = state_machine.nodes
.get(&target)
.ok_or(RPCNetworkError::StateMachineError(StateMachineError::NodeNotFound(target)))?
.clone();
let url = format!("http://{}/{}", addr, uri);
let client = reqwest::Client::new();
let resp = client
.post(url)
.json(&req)
.send()
.await
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
let res: Result<Resp, Err> = resp
.json()
.await
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
res.map_err(|e|
RPCNetworkError::RPCError(RPCError::RemoteError(RemoteError::new(target, e)))
)
}
}
#[async_trait]
impl RaftNetwork<ExampleRequest> for RPCNetwork {
async fn send_append_entries(
&self,
target: NodeId,
req: AppendEntriesRequest<ExampleRequest>,
) -> Result<AppendEntriesResponse, RPCError<AppendEntriesError>> {
self.send_rpc(target, "raft-append", req).await
}
async fn send_install_snapshot(
&self,
target: NodeId,
req: InstallSnapshotRequest,
) -> Result<InstallSnapshotResponse, RPCError<InstallSnapshotError>> {
self.send_rpc(target, "raft-snapshot", req).await
}
async fn send_vote(
&self, target: NodeId,
req: VoteRequest
) -> Result<VoteResponse, RPCError<VoteError>> {
self.send_rpc(target, "raft-vote", req).await
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment