-
-
Save Alonely0/6c58850fecc587c5d7496c6ab3cb9cd3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use async_scoped::TokioScope; | |
use bincode::{deserialize, serialize}; | |
use fs4::tokio::AsyncFileExt; | |
use serde::*; | |
use std::{ | |
collections::HashSet, | |
fmt::Debug, | |
marker::PhantomData, | |
mem::{variant_count, MaybeUninit}, | |
ops::Deref, | |
path::{Path, PathBuf}, | |
time::Duration, | |
}; | |
use tokio::{ | |
fs::{remove_file, try_exists, File as TokioFile}, | |
io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, | |
spawn, | |
sync::{ | |
mpsc::{channel, Receiver, Sender}, | |
Mutex, | |
}, | |
time::timeout, | |
}; | |
// use tokio_uring::fs::File as UringFile; | |
use tracing::log::trace; | |
use anyhow::{anyhow, Result}; | |
use dashmap::{DashMap, DashSet}; | |
use tokio::sync::RwLock; | |
// TODO: consider `flume` instead of `tokio::sync::mpsc` for the watchdog channel.
/// Access checks for a stored record.
///
/// Answers whether an id has been granted read-only (`ro`) or read-write
/// (`rw`) access. The implementations in this file look the id up in the
/// record's `ro_access` and `rw_access` sets independently, so being in the
/// `rw` set does not by itself make `has_ro_access` return `true`.
pub trait Permissions {
    /// Returns `true` if `id` has been granted read-only access.
    fn has_ro_access(&self, id: &u128) -> bool;
    /// Returns `true` if `id` has been granted read-write access.
    fn has_rw_access(&self, id: &u128) -> bool;
}
#[repr(transparent)] | |
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)] | |
pub struct UUIDRef<T>(pub u128, pub PhantomData<T>); | |
impl<T> UUIDRef<T> { | |
pub fn new(id: u128) -> Self { | |
Self(id, PhantomData) | |
} | |
} | |
impl Permissions for Alumne { | |
fn has_ro_access(&self, id: &u128) -> bool { | |
self.ro_access.contains(id) | |
} | |
fn has_rw_access(&self, id: &u128) -> bool { | |
self.rw_access.contains(id) | |
} | |
} | |
impl Permissions for Classe { | |
fn has_ro_access(&self, id: &u128) -> bool { | |
self.ro_access.contains(id) | |
} | |
fn has_rw_access(&self, id: &u128) -> bool { | |
self.rw_access.contains(id) | |
} | |
} | |
// Typed ids pointing at records of the corresponding type.
/// Reference to an [`Alumne`].
pub type AlumneRef = UUIDRef<Alumne>;
/// Reference to a [`Curs`].
pub type CursRef = UUIDRef<Curs>;
/// Reference to a [`Professor`].
pub type ProfessorRef = UUIDRef<Professor>;
/// Reference to a [`Classe`].
pub type ClasseRef = UUIDRef<Classe>;
/// A student record (Catalan: "alumne").
///
/// Field order matters: the record is serialized with bincode, which is
/// positional.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Alumne {
    /// Name.
    pub nom: String,
    /// Course the student belongs to.
    pub curs: CursRef,
    /// Classes the student attends.
    pub classes: Vec<ClasseRef>,
    /// Record of absences; `Expedient` is still an empty placeholder.
    pub faltes: Expedient,
    /// Ids with read-only access to this record (see `Permissions`).
    pub ro_access: DashSet<u128>,
    /// Ids with read-write access to this record (see `Permissions`).
    pub rw_access: DashSet<u128>,
    /// This record's id.
    pub uuid: u128,
}
/// A teacher record. Unlike the other records, its fields are private.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Professor {
    /// Name.
    nom: String,
    /// Classes taught.
    classes: Vec<ClasseRef>,
    /// This record's id.
    uuid: u128,
}
/// A class: one subject taught to a set of students on a timetable.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Classe {
    /// Students attending.
    pub alumnes: Vec<AlumneRef>,
    /// Subject name.
    pub assignatura: String,
    /// Teachers of the class.
    pub professors: Vec<ProfessorRef>,
    /// Timetable slots.
    pub horari: Horari,
    /// Ids with read-only access to this record (see `Permissions`).
    pub ro_access: HashSet<u128>,
    /// Ids with read-write access to this record (see `Permissions`).
    pub rw_access: HashSet<u128>,
    /// This record's id.
    // NOTE(review): every other record and every `DBList` key uses `u128`.
    // Switching to `u128` changes the bincode layout of stored data, so
    // confirm a migration plan before changing it.
    pub uuid: u64,
}
/// A course (school year/group) grouping several classes.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Curs {
    /// Classes belonging to this course.
    pub classes: Vec<ClasseRef>,
    /// This record's id.
    pub uuid: u128,
}
/// Contact details of a legal guardian.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct TutorLegal {
    /// Phone number, stored as free text.
    pub telefon: String,
    /// E-mail address, stored as free text.
    pub email: String,
}
/// A single recorded "falta" (presumably an absence or incident — confirm).
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Falta {
    // NOTE(review): the meaning of this value is not evident from this file — confirm.
    pub discriminant: u128,
    /// Subject it was recorded in.
    pub assignatura: String,
    /// Teacher who recorded it.
    // NOTE(review): field name has a typo (`professsor`); renaming it would
    // break Rust callers, so it is left as is.
    pub professsor: ProfessorRef,
    /// When it happened. Presumably (time of day, weekday), like the slots
    /// of `Horari` — confirm.
    pub quan: (f32, Setmana),
    /// Free-form message.
    pub missatge: String,
}
/// Day of the week (Catalan names, Monday first).
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub enum Setmana {
    Dilluns,
    Dimarts,
    Dimecres,
    Dijous,
    Divendres,
    Dissabte,
    Diumenge,
    /// Value produced by `Default`, so a defaulted record is not silently
    /// assigned a real weekday.
    #[default]
    #[allow(non_camel_case_types)]
    __DEFAULT_DISAMBIGUATION,
}
/// Timetable: a list of slots, presumably (time of day, weekday) — confirm.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Horari(pub Vec<(f32, Setmana)>);
/// A student's record. Currently an empty placeholder.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Expedient {} // todo
/// One table of the database; the variant says what its records are.
///
/// The discriminants are written out explicitly so each table's tag is fixed
/// and visible at a glance. The cache stores the table with tag `i` at index
/// `i`, and migrations append missing tables by tag. Tags must therefore stay
/// `0..variant_count` in declaration order, and new tables must be appended
/// at the end.
///
/// Every payload is wrapped in a [`DBList`], so all variants have the same
/// size and the tag appears only once per table. This relies on `DBList<T>`'s
/// layout not depending on `T`, which holds in practice but is not
/// guaranteed by the language for `repr(Rust)` types.
///
/// `repr(C, usize)` fixes the layout to a `usize` tag followed by the
/// payload; `DBTableLayout` mirrors that layout.
#[derive(Serialize, Deserialize, Debug)]
#[non_exhaustive]
#[repr(C, usize)]
pub enum DBTable {
    Alumnes(DBList<Alumne>) = 0,
    Classes(DBList<Classe>) = 1,
    Cursos(DBList<Curs>) = 2,
    Professors(DBList<Professor>) = 3,
    ProfTokens(DBList<u128>) = 4,
    TutorLegals(DBList<TutorLegal>) = 5,
    TutorsLegalsTokens(DBList<u128>) = 6,
    GlobalRw(DBList<()>) = 7,
    GlobalRo(DBList<()>) = 8,
}
/// Single-variant mirror of `DBTable`'s `repr(C, usize)` layout: the `usize`
/// tag followed by that variant's payload.
///
/// Used to reach a table's payload, or to build an empty table, through a
/// pointer cast instead of a `match`. Using it with a `T` other than the
/// variant's real payload type (e.g. `()` when creating empty tables) relies
/// on `DBList<T>` having the same layout for every `T`.
#[repr(C)]
struct DBTableLayout<T>
where
    T: Clone + Serialize + for<'de> Deserialize<'de>,
{
    /// Discriminant of the `DBTable` variant.
    tag: usize,
    /// The variant's table.
    payload: DBList<T>,
}
impl DBTable { | |
/// # Safety | |
/// T being the correct type for the enum variant must be uphold by the caller. | |
/// This contract is always upheld by the macro [`crate::db`]. | |
pub unsafe fn get_table<T>(&self) -> &DBList<T> | |
where | |
T: Clone + Serialize + for<'de> Deserialize<'de>, | |
{ | |
&unsafe { &*(self as *const _ as *const DBTableLayout<T>) }.payload | |
} | |
pub fn tag(&self) -> usize { | |
unsafe { &*(self as *const _ as *const DBTableLayout<()>) }.tag | |
} | |
} | |
/// In-memory copy of every table; index `i` holds the table whose tag is `i`.
///
/// `repr(transparent)` over the `Vec`, which is (de)serialized as-is with bincode.
#[repr(transparent)]
#[derive(Serialize, Deserialize, Debug)]
pub struct Cache(Vec<DBTable>);
impl Cache { | |
pub fn new() -> Self { | |
const COUNT: usize = variant_count::<DBTable>(); | |
trace!("`DBTable` variants: {COUNT}"); | |
let (ptr, _, _) = Vec::<DBTableLayout<()>>::with_capacity(COUNT).into_raw_parts(); | |
for i in 0..COUNT { | |
trace!("Writing variant {i} to `ptr.add({i})`"); | |
unsafe { | |
ptr.add(i).write(DBTableLayout { | |
tag: i, | |
payload: DBList::<()>(DashMap::with_capacity(0)), | |
}) | |
} | |
} | |
trace!("Finished initializing the DB Cache"); | |
Self(unsafe { Vec::from_raw_parts(ptr.cast(), COUNT, COUNT) }) | |
} | |
fn migrate(&mut self) { | |
trace!("Checking for needed migrations"); | |
let diff = variant_count::<DBTable>() - self.0.len(); | |
if diff > 0 { | |
trace!("Migrating db with diff of {diff}"); | |
self.0.reserve_exact(diff); | |
for i in self.0.len()..self.0.capacity() { | |
trace!("Writing variant {i} to `ptr.add({i})`"); | |
unsafe { | |
self.0 | |
.as_mut_ptr() | |
.cast::<DBTableLayout<()>>() | |
.add(i) | |
.write(DBTableLayout { | |
tag: i, | |
payload: DBList::<()>(DashMap::with_capacity(0)), | |
}) | |
} | |
} | |
} | |
} | |
async fn to_bytes(&self) -> Result<Vec<u8>> { | |
let mut buf = Err(anyhow!("Couldn't serialize the db")); | |
TokioScope::scope_and_block(|s| { | |
s.spawn_blocking(|| buf = serialize(&self).map_err(|e| anyhow!(e))) | |
}); | |
// serialize(&self).map_err(|e| anyhow!(e)) | |
buf | |
} | |
} | |
impl Deref for Cache { | |
type Target = Vec<DBTable>; | |
fn deref(&self) -> &Self::Target { | |
&self.0 | |
} | |
} | |
/// Message sent to the database watchdog task via `DB::notify_watchdog`.
///
/// Note: the first message received right after a flush is consumed only to
/// re-arm the watchdog's timer, so a `Force` arriving at that moment does not
/// trigger a flush by itself.
pub enum WatchdogMessage {
    /// Restarts the watchdog's 2-minute idle timer without flushing.
    Ping,
    /// Asks the watchdog to flush the cache to disk now.
    Force,
}
/// The database handle returned (leaked, `'static`) by `DB::open`.
///
/// Holds the in-memory cache and the channel to the watchdog task, which
/// writes the cache back to `path` on `Force` or after 2 minutes without
/// messages.
#[derive(Debug)]
pub struct DB {
    /// File the cache is loaded from and flushed to.
    path: PathBuf,
    /// All tables, kept in memory; index `i` holds the table with tag `i`.
    pub memory_cache: Cache,
    /// Sending side of the watchdog channel.
    pub watchdog_sx: Sender<WatchdogMessage>,
    /// Receiving side; the watchdog task locks it for its whole lifetime.
    pub watchdog_rx: Mutex<Receiver<WatchdogMessage>>,
}
impl DB { | |
pub async fn open(path: impl AsRef<Path>) -> Result<&'static Self> { | |
let cache = if !try_exists(&path).await? { | |
let file = TokioFile::create(&path).await?; | |
#[cfg(not(miri))] | |
file.lock_exclusive()?; | |
let mut writer = BufWriter::with_capacity(file.metadata().await?.len() as usize, file); | |
let new_db = Cache::new(); | |
writer.write_all(&new_db.to_bytes().await?).await?; | |
writer.shutdown().await?; | |
new_db | |
} else { | |
let file = TokioFile::open(&path).await?; | |
#[cfg(not(miri))] | |
file.lock_exclusive()?; | |
let mut buf = Vec::new(); | |
let mut reader = BufReader::with_capacity(file.metadata().await?.len() as usize, file); | |
reader.read_to_end(&mut buf).await?; | |
let mut db = deserialize::<Cache>(&buf)?; | |
db.migrate(); | |
reader.into_inner().shutdown().await?; | |
db | |
}; | |
let (sx, rx) = channel(3000); | |
let db: &_ = Box::leak(Box::new(Self { | |
path: path.as_ref().to_owned(), | |
memory_cache: cache, | |
watchdog_sx: sx, | |
watchdog_rx: Mutex::new(rx), | |
})); | |
spawn(db.watchdog()); | |
Ok(db) | |
} | |
pub async fn notify_watchdog(&self, msg: WatchdogMessage) { | |
unsafe { self.watchdog_sx.send(msg).await.unwrap_unchecked() } | |
} | |
async fn watchdog(&self) { | |
let mut receiver = self.watchdog_rx.lock().await; | |
// receiver.recv().await; | |
loop { | |
match timeout(Duration::from_millis(2 * 60 * 1000), receiver.recv()).await { | |
Ok(Some(WatchdogMessage::Ping)) => continue, | |
Ok(Some(WatchdogMessage::Force)) | Err(_) if self.flush().await.is_ok() => { | |
while receiver.try_recv().is_ok() {} | |
receiver.recv().await; | |
} | |
_ => { | |
panic!("Database exploded"); | |
} | |
} | |
} | |
} | |
async fn flush(&self) -> Result<()> { | |
// fixme switch to [`UringFile`] when [`tokio_uring`] becomes Send | |
remove_file(&self.path).await?; | |
let file = TokioFile::create(&self.path).await?; | |
let mut writer = BufWriter::with_capacity(file.metadata().await?.len() as usize, file); | |
writer | |
.write_all(&self.memory_cache.to_bytes().await?) | |
.await?; | |
writer.flush().await?; | |
writer.shutdown().await?; | |
Ok(()) | |
} | |
} | |
/// A table: records keyed by their `u128` id, each behind its own tokio
/// [`RwLock`] so records can be locked individually.
///
/// `repr(transparent)` over the `DashMap`. It is serialized as a plain
/// `u128 -> T` map; the locks are not part of the on-disk format.
#[repr(transparent)]
#[derive(Default, Debug)]
pub struct DBList<T>(pub DashMap<u128, RwLock<T>>)
where
    T: Clone + Serialize + for<'a> Deserialize<'a>;
impl<T> Serialize for DBList<T> | |
where | |
T: Clone + Serialize + for<'a> Deserialize<'a>, | |
{ | |
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | |
where | |
S: Serializer, | |
{ | |
Serialize::serialize( | |
&self | |
.0 | |
.iter() | |
.map(|x| (*x.key(), x.value().blocking_read().clone())) | |
.collect::<DashMap<u128, T>>(), | |
serializer, | |
) | |
} | |
} | |
impl<'de, T> Deserialize<'de> for DBList<T> | |
where | |
T: Clone + Serialize + for<'a> Deserialize<'a>, | |
{ | |
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | |
where | |
D: Deserializer<'de>, | |
{ | |
let val: DashMap<u128, T> = Deserialize::deserialize(deserializer)?; | |
Ok(Self( | |
val.into_iter().map(|(i, x)| (i, RwLock::new(x))).collect(), | |
)) | |
} | |
} | |
impl<T> DBList<T> | |
where | |
T: Clone + Serialize + for<'a> Deserialize<'a>, | |
{ | |
pub fn insert(&self, element: T) -> u128 { | |
let uuid = rand::random(); | |
self.0.insert(uuid, RwLock::new(element)); | |
uuid | |
} | |
} | |
pub(crate) mod macros { | |
#![allow(dead_code)] | |
#[macro_export] | |
macro_rules! db { | |
() => { | |
#[allow(unused_unsafe)] | |
unsafe { | |
*$crate::DATABASE.get().unwrap_unchecked() | |
} | |
}; | |
(Table::$x:ident) => { | |
db!(::Table::$x).0 | |
}; | |
(::Table::$x:ident) => { | |
#[allow(unused_unsafe)] | |
unsafe { | |
db!().memory_cache[DBTable::$x($crate::database::macros::zero()).tag()] | |
.get_table::<$x>() | |
} | |
}; | |
} | |
/// HACK this is used to build an enum with the correct variant | |
/// on the [`db`] macro and get the discriminant. The optimizer | |
/// probably will remove the function call since its return | |
/// value is never read. | |
#[inline(always)] | |
pub const unsafe fn zero<T>() -> T { | |
MaybeUninit::zeroed().assume_init() | |
} | |
use super::*; | |
pub type Alumnes = Alumne; | |
pub type Classes = Classe; | |
pub type Cursos = Curs; | |
pub type Professors = Professor; | |
pub type ProfTokens = u128; | |
pub type TutorLegals = TutorLegal; | |
pub type TutorsLegalsTokens = u128; | |
pub type GlobalRw = (); | |
pub type GlobalRo = (); | |
} | |
pub use macros::*; | |
#[cfg(test)]
mod tests {
    use crate::{db, DATABASE};
    use tokio::{fs::remove_file, time::sleep};

    use super::*;

    /// Inserts a few records, forces a flush through the watchdog, reopens
    /// the file and checks the records were persisted.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn read_write_db() {
        const INSTANCES: usize = 4;
        // Start from a clean slate; a missing file is fine.
        remove_file("./temp_db").await.unwrap_or(());
        DB::open("./temp_db")
            .await
            .and_then(|db| DATABASE.set(db).map_err(|e| anyhow!("{e:?}")))
            .expect("Can't open the database");
        for _ in 0..INSTANCES {
            db!(Table::Alumnes).insert(rand::random::<u128>(), RwLock::new(Alumne::default()));
        }
        db!().notify_watchdog(WatchdogMessage::Force).await;
        // The flush runs on the watchdog task; give it time to finish.
        sleep(Duration::from_millis(1000)).await;
        // Build a real (empty) table to read the tag from. The previous
        // `DBTable::Alumnes(zero())` created an all-zero `DashMap`, which is an
        // invalid value.
        let alumnes_tag = DBTable::Alumnes(DBList::default()).tag();
        DB::open("./temp_db")
            .await
            .inspect(|&db| {
                assert_eq!(
                    unsafe {
                        db.memory_cache[alumnes_tag]
                            .get_table::<Alumne>()
                            .0
                            .len()
                    },
                    INSTANCES
                )
            })
            .expect("Can't open the database");
        remove_file("./temp_db").await.unwrap_or(());
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment