-
-
Save ppamorim/8294fde7025cb66bb8639551c8abac2d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::sync::Arc; | |
use async_trait::async_trait; | |
use openraft::error::AppendEntriesError; | |
use openraft::error::InstallSnapshotError; | |
use openraft::error::NetworkError; | |
use openraft::error::RPCError; | |
use openraft::error::RemoteError; | |
use openraft::error::VoteError; | |
use openraft::raft::AppendEntriesRequest; | |
use openraft::raft::AppendEntriesResponse; | |
use openraft::raft::InstallSnapshotRequest; | |
use openraft::raft::InstallSnapshotResponse; | |
use openraft::raft::VoteRequest; | |
use openraft::raft::VoteResponse; | |
use openraft::NodeId; | |
use openraft::RaftNetwork; | |
use serde::Deserialize; | |
use serde::de::DeserializeOwned; | |
use serde::Serialize; | |
use std::error::Error; | |
use crate::store::ExampleRequest; | |
use crate::MemStore; | |
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, thiserror::Error)] | |
pub enum StateMachineError { | |
#[error("Node not found in the state machine")] | |
NodeNotFound(NodeId), | |
} | |
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, thiserror::Error)] | |
pub enum RPCNetworkError<T: Error> { | |
RPCError(RPCError<T>), | |
StateMachineError(StateMachineError) | |
} | |
pub struct RPCNetwork { | |
pub store: Arc<MemStore>, | |
} | |
impl RPCNetwork { | |
pub async fn send_rpc<Req, Resp, Err>( | |
&self, | |
target: NodeId, | |
uri: &str, | |
req: Req | |
) -> Result<Resp, RPCNetworkError<Err>> | |
where | |
Req: Serialize, | |
Err: std::error::Error + DeserializeOwned, | |
Resp: DeserializeOwned, | |
{ | |
let state_machine = self.store.state_machine.read().await; | |
let addr = state_machine.nodes | |
.get(&target) | |
.ok_or(RPCNetworkError::StateMachineError(StateMachineError::NodeNotFound(target)))? | |
.clone(); | |
let url = format!("http://{}/{}", addr, uri); | |
let client = reqwest::Client::new(); | |
let resp = client | |
.post(url) | |
.json(&req) | |
.send() | |
.await | |
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?; | |
let res: Result<Resp, Err> = resp | |
.json() | |
.await | |
.map_err(|e| RPCError::Network(NetworkError::new(&e)))?; | |
res.map_err(|e| | |
RPCNetworkError::RPCError(RPCError::RemoteError(RemoteError::new(target, e))) | |
) | |
} | |
} | |
#[async_trait] | |
impl RaftNetwork<ExampleRequest> for RPCNetwork { | |
async fn send_append_entries( | |
&self, | |
target: NodeId, | |
req: AppendEntriesRequest<ExampleRequest>, | |
) -> Result<AppendEntriesResponse, RPCError<AppendEntriesError>> { | |
self.send_rpc(target, "raft-append", req).await | |
} | |
async fn send_install_snapshot( | |
&self, | |
target: NodeId, | |
req: InstallSnapshotRequest, | |
) -> Result<InstallSnapshotResponse, RPCError<InstallSnapshotError>> { | |
self.send_rpc(target, "raft-snapshot", req).await | |
} | |
async fn send_vote( | |
&self, target: NodeId, | |
req: VoteRequest | |
) -> Result<VoteResponse, RPCError<VoteError>> { | |
self.send_rpc(target, "raft-vote", req).await | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment