Skip to content

Instantly share code, notes, and snippets.

@frederikbosch
Created February 3, 2020 11:08
Show Gist options
  • Save frederikbosch/de51d7d329c5c2097e72361f1c4a0764 to your computer and use it in GitHub Desktop.
Save frederikbosch/de51d7d329c5c2097e72361f1c4a0764 to your computer and use it in GitHub Desktop.
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use tokio::sync::{mpsc, watch};
use tokio::sync::RwLock;
use tonic::{Code, Request, Response, Status};
use def::health_check_response::ServingStatus;
use def::health_server;
use def::{HealthCheckRequest, HealthCheckResponse};
/// Generated protobuf/gRPC definitions for the `grpc.health.v1` package,
/// pulled in at compile time by `tonic::include_proto!`.
pub mod def {
tonic::include_proto!("grpc.health.v1");
}
/// Result of a health RPC handler: a tonic `Response` wrapping `T` on success,
/// or a gRPC `Status` on failure.
type HealthResult<T> = Result<Response<T>, Status>;
/// Stream type used for server-streaming responses: the receiving half of an
/// mpsc channel whose items are either a `T` or a gRPC `Status` error.
type ResponseStream<T> = mpsc::Receiver<Result<T, Status>>;
/// Registry of serving statuses keyed by service name, paired with a `watch`
/// channel over which status changes are announced.
pub struct ServiceRegister {
    /// Current status for every known service name.
    services: Mutex<HashMap<String, ServingStatus>>,
    /// Sending half of the change channel; a change is published as
    /// `Some((service_name, new_status))`.
    signaller: Mutex<watch::Sender<Option<(String, ServingStatus)>>>,
    /// Receiving half of the change channel, kept so that `subscribe` can
    /// hand out clones of it.
    subscriber: watch::Receiver<Option<(String, ServingStatus)>>,
}
impl ServiceRegister {
    /// Creates a register in which the empty service name is mapped to
    /// `initial`.
    ///
    /// NOTE(review): in the gRPC health-checking convention the empty name
    /// presumably stands for the server as a whole — confirm against callers.
    ///
    /// The change channel starts out holding `None`, i.e. no change has been
    /// announced yet.
    pub fn new(initial: ServingStatus) -> Self {
        let mut services = HashMap::new();
        services.insert(String::new(), initial);
        let (tx, rx) = watch::channel(None);
        ServiceRegister {
            services: Mutex::new(services),
            signaller: Mutex::new(tx),
            subscriber: rx,
        }
    }

    /// Returns the current status of `service_name`.
    ///
    /// Yields `None` when the service is not registered or when the internal
    /// mutex is poisoned.
    pub fn get_status(&self, service_name: &str) -> Option<ServingStatus> {
        let services = self.services.lock().ok()?;
        services.get(service_name).copied()
    }

    /// Records `status` for `service_name` (registering the service if it is
    /// new) and announces the change to subscribers when the status actually
    /// differs from the previous one.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when an internal mutex is poisoned or when the
    /// change could not be broadcast. In the broadcast-failure case the new
    /// status has already been stored.
    pub fn set_status(&mut self, service_name: &str, status: ServingStatus) -> Result<(), ()> {
        let services = self.services.get_mut().map_err(|_| ())?;
        let previous = services.insert(service_name.to_string(), status);

        // Only announce real transitions (new service, or a different status).
        if previous != Some(status) {
            // `&mut self` guarantees exclusive access, so `get_mut` never
            // blocks; a poisoned lock is reported instead of being silently
            // skipped.
            let signaller = self.signaller.get_mut().map_err(|_| ())?;
            signaller
                .broadcast(Some((service_name.to_string(), status)))
                .map_err(|_| ())?;
        }
        Ok(())
    }

    /// Returns a new receiver for status-change announcements.
    ///
    /// NOTE(review): a tokio `watch` channel retains only the most recent
    /// value, so several changes published in quick succession (for example
    /// during `shutdown`) may be coalesced before a subscriber observes them
    /// — confirm this is acceptable.
    pub fn subscribe(&self) -> watch::Receiver<Option<(String, ServingStatus)>> {
        self.subscriber.clone()
    }

    /// Marks every registered service as `NotServing` and announces each
    /// transition.
    ///
    /// All services are updated even if announcing one of them fails, so a
    /// notification problem can never leave a service reported as serving
    /// after shutdown.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when the services mutex is poisoned (nothing is
    /// changed), or when at least one transition could not be announced
    /// (statuses are still updated).
    pub fn shutdown(&mut self) -> Result<(), ()> {
        let services = self.services.get_mut().map_err(|_| ())?;
        // Acquire the sender once instead of once per service; `None` means
        // the signaller mutex is poisoned.
        let signaller = self.signaller.get_mut().ok();

        let mut all_notified = true;
        for (name, status) in services.iter_mut() {
            if *status == ServingStatus::NotServing {
                continue;
            }
            *status = ServingStatus::NotServing;

            let sent = match signaller.as_ref() {
                Some(sig) => sig.broadcast(Some((name.clone(), *status))).is_ok(),
                None => false,
            };
            all_notified = all_notified && sent;
        }

        if all_notified {
            Ok(())
        } else {
            Err(())
        }
    }
}
/// gRPC health service that answers requests from a shared `ServiceRegister`.
pub struct HealthCheckService {
    /// Shared register consulted for the status of each requested service.
    register: Arc<RwLock<ServiceRegister>>,
}
impl HealthCheckService {
pub fn new(register: Arc<RwLock<ServiceRegister>>) -> Self {
HealthCheckService { register }
}
}
#[tonic::async_trait]
impl health_server::Health for HealthCheckService {
async fn check(
&self,
request: Request<HealthCheckRequest>,
) -> HealthResult<HealthCheckResponse> {
match self.register.read().await.get_status(&request.get_ref().service) {
Some(status) => {
let response = Response::new(HealthCheckResponse {
status: status.clone() as i32,
});
Ok(response)
}
None => Err(Status::new(Code::NotFound, "")),
}
}
type WatchStream = ResponseStream<HealthCheckResponse>;
async fn watch(&self, request: Request<HealthCheckRequest>) -> HealthResult<Self::WatchStream> {
let name = &request.get_ref().service;
let (mut tx, res_rx) = mpsc::channel(10);
if let Some(status) = self.register.read().await.get_status(name) {
let _ = tx.send(Ok(HealthCheckResponse {
status: status as i32,
}))
.await;
let mut rx = self.register.read().await.subscribe();
while let Some(value) = rx.recv().await {
match value {
Some((changed_name, status)) => {
if *name == changed_name {
let _ = tx.send(Ok(HealthCheckResponse {
status: status as i32,
}))
.await;
}
},
_ => {},
}
}
Ok(Response::new(res_rx))
} else {
Err(Status::new(Code::NotFound, ""))
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment