@mooreniemi
Created May 20, 2024 03:50
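# Cargo.toml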
[package]
name = "cete_node"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
actix-web = "4"
aws-config = "*"
aws-sdk-codecommit = "*"
aws-sdk-ecs = "*"
clap = { version = "4.5", features = ["derive"] }
clap_derive = "4.5.4"
env_logger = "0.9"
etcd-client = "0.12"
log = "0.4"
port_scanner = "*"
reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["full"] }
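// main.rs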
use actix_web::{
middleware,
web::{self},
App, HttpResponse, HttpServer, Responder,
};
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_codecommit as codecommit;
use clap::Parser;
use etcd_client::{Client, Compare, CompareOp, Txn, TxnOp};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{
collections::BTreeMap,
env, fmt,
net::TcpListener,
sync::{Arc, Mutex},
};
use tokio::{
signal::unix::{signal, SignalKind},
sync::Mutex as AsyncMutex,
};
/// How many tasks are serving this shard?
#[derive(Serialize, Deserialize, Debug)]
struct ShardData {
count: i32,
}
/// Which shard is this task assigned to serve?
#[derive(Serialize, Deserialize, Debug)]
struct TaskData {
shard_id: usize,
}
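/// Probe a task's local /me endpoint; returns true only on an HTTP 200 response.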
async fn is_task_up(port: u16) -> bool {
let url = format!("http://localhost:{}/me", port);
match reqwest::get(url).await {
Ok(response) => response.status().as_u16() == 200,
Err(e) => {
// note: this error is expected when the task is down, so it is only logged at debug level
log::debug!("{}", e);
false
}
}
}
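/// Handler: list every registered task in etcd with its assigned shard
/// and whether the task currently responds on its port.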
async fn show_tasks(data: web::Data<AppState>) -> impl Responder {
log::info!("showing tasks across the entire cluster");
let mut client = data.etcd.lock().await;
// note: output is less confusing if it is always in sorted order
let mut tasks = BTreeMap::new();
let prefix = "task/";
log::info!("searching prefix: {}", prefix);
let response = client
.get(prefix, Some(etcd_client::GetOptions::new().with_prefix()))
.await
.expect("got keys");
log::info!("found some keys");
for kv in response.kvs() {
let key = String::from_utf8_lossy(kv.key());
let port = key
.split("/")
.nth(1)
.unwrap_or_default()
.parse::<u16>()
.expect("valid port");
log::info!("checking key: {}", key);
let value = serde_json::from_slice::<TaskData>(&kv.value())
.expect("valid shard data")
.shard_id;
tasks.insert(key, (value, is_task_up(port).await));
}
HttpResponse::Ok().json(tasks)
}
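/// Handler: report the replica count recorded in etcd for every shard.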
async fn show_assignments(data: web::Data<AppState>) -> impl Responder {
log::info!("showing assignments across the entire cluster");
let mut client = data.etcd.lock().await;
// note: output is less confusing if it is always in sorted order
let mut shards = BTreeMap::new();
for n in 0..data.total_shards {
let prefix = format!("shard_counts/shard_{}", n);
log::info!("searching prefix: {}", prefix);
let response = client
.get(
prefix.clone(),
Some(etcd_client::GetOptions::new().with_prefix()),
)
.await
.expect("got keys");
log::info!("found some keys");
for kv in response.kvs() {
let key = String::from_utf8_lossy(kv.key());
log::info!("checking key: {}", key);
let value = serde_json::from_slice::<ShardData>(&kv.value())
.expect("valid shard data")
.count;
shards.insert(prefix.clone(), value);
}
}
HttpResponse::Ok().json(shards)
}
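/// Handler: report this task's shard assignment (plus port and role),
/// determining an assignment lazily if none has been cached yet.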
async fn show_assignment(data: web::Data<AppState>) -> impl Responder {
let mut shard_assignment = data.shard_id.lock().expect("got shard_id");
let mut shard_assignment_value = shard_assignment.clone();
if shard_assignment_value.is_none() {
log::warn!("initializing shard id since none was found");
let new_shard = {
determine_shard_assignment(
&data.etcd,
data.total_shards,
data.replicas_per_shard,
data.port,
)
.await
.expect("can determine shard")
};
shard_assignment_value = Some(new_shard);
*shard_assignment = shard_assignment_value;
}
HttpResponse::Ok().json(json!({
"shard": shard_assignment_value.unwrap(),
"task": data.port,
"role": data.role,
}))
}
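/// Shutdown path: look up which shard this task served, decrement that shard's
/// replica count, and delete the task's key from etcd.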
async fn cleanup_assignment(etcd: &AsyncMutex<Client>, port: u16) {
log::info!("I am task: {}. Shutting down.", port);
let mut client = etcd.lock().await;
let task_key = format!("task/{}", port);
let response = client
.get(task_key.clone(), None)
.await
.expect("got task info");
let shard_id = response
.kvs()
.get(0)
.map(|kv| {
serde_json::from_slice::<TaskData>(&kv.value())
.unwrap()
.shard_id
})
.unwrap_or(0) as usize;
log::info!(
"As task {}, I handled {}. This assignment will be released.",
port,
shard_id
);
let shard_key = format!("shard_counts/shard_{}", shard_id);
let response = client
.get(shard_key.clone(), None)
.await
.expect("got shard counts");
let count = response
.kvs()
.get(0)
.map(|kv| {
serde_json::from_slice::<ShardData>(&kv.value())
.unwrap()
.count
})
.expect("should have non-zero shards counted") as usize;
log::info!(
"As task {}, I handled {}, which will have count {} after release.",
port,
shard_id,
count - 1
);
let value = json!({ "count": count - 1 }).to_string();
client
.put(shard_key.clone(), value, None)
.await
.expect("decremented count");
client.delete(task_key, None).await.expect("deleted task");
log::info!("Task {} finished, shard released.", port);
}
// given the greedy placement strategy, recalculating will almost always return the same assignment
async fn recalculate_assignment(data: web::Data<AppState>) -> impl Responder {
let response = match determine_shard_assignment(
&data.etcd,
data.total_shards,
data.replicas_per_shard,
data.port,
)
.await
{
Ok(shard_id) => {
json!({"shard_id": shard_id})
}
Err(e) => {
json!({"error": format!("{}", e)})
}
};
HttpResponse::Ok().json(response)
}
// uses an etcd transaction to safely initialize or increment a shard's replica count
// we don't retry in this function: if the compare-and-swap failed, another node already took the slot,
// so we need to bail out entirely and let the caller move on to the next potentially available shard
async fn increment_shard_count(
client: &mut Client,
key: &str,
potential_initial_value: Option<usize>,
) -> Result<(), Box<dyn std::error::Error>> {
let key = key.to_string();
log::info!(
"will increment {:?}, from {:?}",
key,
potential_initial_value
);
// note: creating the data for the first time must be handled differently than mutating it once it exists
let txn = match potential_initial_value {
Some(iv) => {
// note: only succeed if the stored count still matches the value we just read,
// then write the incremented count; otherwise fall through so the caller can bail out
let current_value_as_json = json!({"count": iv}).to_string();
let new_value_as_json = json!({"count": iv + 1}).to_string();
let txn = Txn::new();
let cmp = Compare::value(key.clone(), CompareOp::Equal, current_value_as_json);
let succ = TxnOp::put(key.clone(), new_value_as_json, None);
let fail = TxnOp::get(key.clone(), None);
txn.when(vec![cmp]).and_then(vec![succ]).or_else(vec![fail])
}
None => {
let initial_value_as_json = json!({"count": 1}).to_string();
let txn = Txn::new();
// note: version, not value, here - a version of 0 means the key still does not exist
let cmp = Compare::version(key.clone(), CompareOp::Equal, 0);
let succ = TxnOp::put(key.clone(), initial_value_as_json, None);
let fail = TxnOp::get(key.clone(), None);
txn.when(vec![cmp]).and_then(vec![succ]).or_else(vec![fail])
}
};
let txn_resp = client
.txn(txn)
.await
.expect("got etcd response successfully");
log::debug!("finished transaction: {:?}", txn_resp);
if txn_resp.succeeded() {
Ok(())
} else {
Err("transaction was not successful".into())
}
}
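/// Greedily assign this task to the first shard whose replica count is below
/// `replicas_per_shard`, record the choice under `task/{port}`, and return the shard id.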
async fn determine_shard_assignment(
etcd: &AsyncMutex<Client>,
total_shards: usize,
replicas_per_shard: usize,
port: u16,
) -> Result<usize, Box<dyn std::error::Error>> {
log::info!("time to determine the shard");
let mut client = etcd.lock().await;
// note: essentially does greedy placement,
// finding the first shard without a complete replica set
// and adding this task id to the replica set
for shard_id in 0..total_shards {
let key = format!("shard_counts/shard_{}", shard_id);
// note: between this read and the update, the value can change,
// which is why we may need to move on to the next shard
let response = client.get(key.clone(), None).await?;
let maybe_count = response.kvs().get(0).map(|kv| {
serde_json::from_slice::<ShardData>(&kv.value())
.unwrap()
.count as usize
});
let count = maybe_count.unwrap_or(0);
if count < replicas_per_shard {
// note: must reuse the already-locked &mut client here, or we'd deadlock on the etcd mutex taken above
match increment_shard_count(&mut client, &key, maybe_count).await {
Ok(_) => {
log::info!("adding replica to count for {}", shard_id);
// note: store this so we can use it to look up and decrement later on shutdown
let value = json!({ "shard_id": shard_id}).to_string();
client.put(format!("task/{}", port), value, None).await?;
return Ok(shard_id);
}
Err(_) => {
log::info!(
"another task stole the assignment for {}, trying another shard",
shard_id
);
continue;
}
}
}
}
Err("No available shards".into())
}
// not used in local runs, but kept around in case fetching remote config is wanted later
async fn fetch_config_from_codecommit(
make_outbound_request: bool,
default_number_shards: usize,
default_replicas_per_shard: usize,
) -> Result<(usize, usize), Box<dyn std::error::Error>> {
if make_outbound_request {
let region_provider = RegionProviderChain::default_provider().or_else("us-west-2");
let config = aws_config::from_env().region(region_provider).load().await;
let client = codecommit::Client::new(&config);
let content = client
.get_file()
.repository_name("my-config-repo")
.file_path("config.json")
.send()
.await?
.file_content()
.as_ref()
.to_vec();
let config_data: serde_json::Value = serde_json::from_slice(&content)?;
let total_shards = config_data["total_shards"]
.as_u64()
.expect("Expected total_shards") as usize;
let replicas_per_shard = config_data["replicas_per_shard"]
.as_u64()
.expect("Expected replicas_per_shard") as usize;
Ok((total_shards, replicas_per_shard))
} else {
let total_shards = default_number_shards;
// these start at 1, that is, the primary is the first replica
let replicas_per_shard = default_replicas_per_shard;
Ok((total_shards, replicas_per_shard))
}
}
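/// Shared state handed to every request handler.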
struct AppState {
etcd: Arc<AsyncMutex<Client>>,
total_shards: usize,
replicas_per_shard: usize,
/// the fixed identity of the task
port: u16,
/// which type of node this is
role: String,
/// the dynamic and assigned identity of the shard
shard_id: Mutex<Option<usize>>,
}
impl fmt::Debug for AppState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AppState")
.field("etcd", &"Arc<AsyncMutext<Client>>> { ... }") // Custom display for Inner
.field("total_shards", &self.total_shards)
.field("replicas_per_shard", &self.replicas_per_shard)
.field("port", &self.port)
.finish()
}
}
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// Role of the cete node
#[arg(short, long)]
role: String,
/// total number distinct shards
#[arg(short, long, default_value_t = 1)]
shards: usize,
/// total number replicas per shard
#[arg(long, default_value_t = 1)]
replicas: usize,
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
env_logger::init();
let args = Args::parse();
let (total_shards, replicas_per_shard) =
fetch_config_from_codecommit(false, args.shards, args.replicas)
.await
.expect("Failed to fetch configuration");
let etcd_address = env::var("ETCD_ADDRESS").unwrap_or_else(|_| "http://localhost:2379".into());
let client = Client::connect([etcd_address], None)
.await
.expect("connect to ectd");
let (listener, port) = match args.role.as_str() {
// note: just easier to test with when we have at least one node with a fixed address
"front" => {
let port = 3000;
let listener = TcpListener::bind("0.0.0.0:3000").expect("Failed to bind to front port");
(listener, port)
}
"inner" => {
let listener =
TcpListener::bind("0.0.0.0:0").expect("Failed to bind to inner, ephemeral port");
let port = listener.local_addr().unwrap().port();
(listener, port)
}
_ => {
todo!("no other roles")
}
};
let etcd = Arc::new(AsyncMutex::new(client));
let binding = etcd.clone();
let shard_id = determine_shard_assignment(&binding, total_shards, replicas_per_shard, port)
.await
.expect("assigned_shard");
let data = web::Data::new(AppState {
etcd: etcd.clone(),
total_shards,
replicas_per_shard,
port,
role: args.role,
shard_id: Mutex::new(Some(shard_id)),
});
log::info!("AppState: {:?}", &data);
// note: actix doesn't have its own shutdown hook, so we listen for signals and do cleanup manually
tokio::spawn(async move {
let mut terminate = signal(SignalKind::terminate()).unwrap();
let mut interrupt = signal(SignalKind::interrupt()).unwrap();
tokio::select! {
_ = terminate.recv() => {
println!("Received SIGTERM signal, shutting down...");
cleanup_assignment(&etcd, port).await;
}
_ = interrupt.recv() => {
println!("Received SIGINT signal, shutting down...");
cleanup_assignment(&etcd, port).await;
}
}
});
HttpServer::new(move || {
App::new()
.wrap(middleware::Logger::default())
.app_data(data.clone())
.route("/me", web::get().to(show_assignment))
.route("/", web::get().to(show_assignments))
.route("/tasks", web::get().to(show_tasks))
.route("/redo", web::get().to(recalculate_assignment))
})
.listen(listener)?
.run()
.await
}
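// A minimal local smoke test, assuming an etcd instance is already listening on
// localhost:2379 (or wherever ETCD_ADDRESS points):
//
//   RUST_LOG=info cargo run -- --role front --shards 2 --replicas 2
//   RUST_LOG=info cargo run -- --role inner --shards 2 --replicas 2
//
//   curl localhost:3000/me      # this task's shard assignment
//   curl localhost:3000/        # replica counts per shard
//   curl localhost:3000/tasks   # all registered tasks and their liveness
//   curl localhost:3000/redo    # recalculate this task's assignment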