@Alonely0

Alonely0/db.rs Secret

Last active April 7, 2023 17:26
use async_scoped::TokioScope;
use bincode::{deserialize, serialize};
use fs4::tokio::AsyncFileExt;
use serde::*;
use std::{
    collections::HashSet,
    fmt::Debug,
    marker::PhantomData,
    mem::{variant_count, MaybeUninit},
    ops::Deref,
    path::{Path, PathBuf},
    time::Duration,
};
use tokio::{
    fs::{remove_file, try_exists, File as TokioFile},
    io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter},
    spawn,
    sync::{
        mpsc::{channel, Receiver, Sender},
        Mutex,
    },
    time::timeout,
};
// use tokio_uring::fs::File as UringFile;
use tracing::log::trace;
use anyhow::{anyhow, Result};
use dashmap::{DashMap, DashSet};
use tokio::sync::RwLock;
// todo use flume instead of [`std::sync::mpsc`]
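/// Access-control hooks shared by record types that keep per-user permission
/// sets: `has_ro_access`/`has_rw_access` check whether the given user UUID may
/// read or modify the record.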
pub trait Permissions {
    fn has_ro_access(&self, id: &u128) -> bool;
    fn has_rw_access(&self, id: &u128) -> bool;
}
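/// A typed, zero-cost reference to another record: just the target's UUID plus
/// a `PhantomData` marker pinning the referenced type.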
#[repr(transparent)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
pub struct UUIDRef<T>(pub u128, pub PhantomData<T>);
impl<T> UUIDRef<T> {
    pub fn new(id: u128) -> Self {
        Self(id, PhantomData)
    }
}
impl Permissions for Alumne {
    fn has_ro_access(&self, id: &u128) -> bool {
        self.ro_access.contains(id)
    }
    fn has_rw_access(&self, id: &u128) -> bool {
        self.rw_access.contains(id)
    }
}
impl Permissions for Classe {
    fn has_ro_access(&self, id: &u128) -> bool {
        self.ro_access.contains(id)
    }
    fn has_rw_access(&self, id: &u128) -> bool {
        self.rw_access.contains(id)
    }
}
pub type AlumneRef = UUIDRef<Alumne>;
pub type CursRef = UUIDRef<Curs>;
pub type ProfessorRef = UUIDRef<Professor>;
pub type ClasseRef = UUIDRef<Classe>;
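// The domain types below use Catalan names: Alumne = student, Professor = teacher,
// Classe = class, Curs = course/school year, TutorLegal = legal guardian,
// Falta = absence/incident, Horari = timetable, Expedient = student record,
// and Setmana lists the days of the week (Dilluns..Diumenge = Monday..Sunday).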
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Alumne {
    pub nom: String,
    pub curs: CursRef,
    pub classes: Vec<ClasseRef>,
    pub faltes: Expedient,
    pub ro_access: DashSet<u128>,
    pub rw_access: DashSet<u128>,
    pub uuid: u128,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Professor {
    nom: String,
    classes: Vec<ClasseRef>,
    uuid: u128,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Classe {
    pub alumnes: Vec<AlumneRef>,
    pub assignatura: String,
    pub professors: Vec<ProfessorRef>,
    pub horari: Horari,
    pub ro_access: HashSet<u128>,
    pub rw_access: HashSet<u128>,
    pub uuid: u128,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Curs {
    pub classes: Vec<ClasseRef>,
    pub uuid: u128,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct TutorLegal {
    pub telefon: String,
    pub email: String,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Falta {
    pub discriminant: u128,
    pub assignatura: String,
    pub professor: ProfessorRef,
    pub quan: (f32, Setmana),
    pub missatge: String,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub enum Setmana {
    Dilluns,
    Dimarts,
    Dimecres,
    Dijous,
    Divendres,
    Dissabte,
    Diumenge,
    #[default]
    #[allow(non_camel_case_types)]
    __DEFAULT_DISAMBIGUATION,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Horari(pub Vec<(f32, Setmana)>);
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(default)]
pub struct Expedient {} // todo
/// The discriminants are explicitly marked because the compiler doesn't
/// guarantee that they are assigned sequentially from 0 to repr::MAX.
/// Every variant's payload is a [`DBList`] so that all variants have the
/// same size and the discriminant appears only once per table.
#[derive(Serialize, Deserialize, Debug)]
#[non_exhaustive]
#[repr(C, usize)]
pub enum DBTable {
    Alumnes(DBList<Alumne>) = 0,
    Classes(DBList<Classe>) = 1,
    Cursos(DBList<Curs>) = 2,
    Professors(DBList<Professor>) = 3,
    ProfTokens(DBList<u128>) = 4,
    TutorLegals(DBList<TutorLegal>) = 5,
    TutorsLegalsTokens(DBList<u128>) = 6,
    GlobalRw(DBList<()>) = 7,
    GlobalRo(DBList<()>) = 8,
}
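/// Mirrors the `#[repr(C, usize)]` layout of [`DBTable`]: an explicit tag
/// followed by the payload, so a variant's payload can be read through a
/// pointer cast without knowing the variant statically.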
#[repr(C)]
struct DBTableLayout<T>
where
    T: Clone + Serialize + for<'de> Deserialize<'de>,
{
    tag: usize,
    payload: DBList<T>,
}
impl DBTable {
    /// # Safety
    /// The caller must uphold that `T` is the correct payload type for this
    /// enum variant. This contract is always upheld by the [`crate::db`] macro.
    pub unsafe fn get_table<T>(&self) -> &DBList<T>
    where
        T: Clone + Serialize + for<'de> Deserialize<'de>,
    {
        &unsafe { &*(self as *const _ as *const DBTableLayout<T>) }.payload
    }
    pub fn tag(&self) -> usize {
        unsafe { &*(self as *const _ as *const DBTableLayout<()>) }.tag
    }
}
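/// The in-memory image of the whole database: one [`DBTable`] per enum
/// variant, indexed by the variant's discriminant.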
#[repr(transparent)]
#[derive(Serialize, Deserialize, Debug)]
pub struct Cache(Vec<DBTable>);
impl Cache {
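    /// Builds an empty table for every [`DBTable`] variant by writing
    /// `{ tag, empty list }` pairs through a raw pointer and reinterpreting the
    /// buffer as `Vec<DBTable>`; this leans on every payload being a [`DBList`]
    /// with the same layout.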
    pub fn new() -> Self {
        const COUNT: usize = variant_count::<DBTable>();
        trace!("`DBTable` variants: {COUNT}");
        let (ptr, _, _) = Vec::<DBTableLayout<()>>::with_capacity(COUNT).into_raw_parts();
        for i in 0..COUNT {
            trace!("Writing variant {i} to `ptr.add({i})`");
            unsafe {
                ptr.add(i).write(DBTableLayout {
                    tag: i,
                    payload: DBList::<()>(DashMap::with_capacity(0)),
                })
            }
        }
        trace!("Finished initializing the DB Cache");
        Self(unsafe { Vec::from_raw_parts(ptr.cast(), COUNT, COUNT) })
    }
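    /// Appends an empty table for every variant added to [`DBTable`] since the
    /// on-disk file was written, so older databases keep deserializing.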
    fn migrate(&mut self) {
        trace!("Checking for needed migrations");
        let diff = variant_count::<DBTable>() - self.0.len();
        if diff > 0 {
            trace!("Migrating db with diff of {diff}");
            self.0.reserve_exact(diff);
            let old_len = self.0.len();
            for i in old_len..old_len + diff {
                trace!("Writing variant {i} to `ptr.add({i})`");
                unsafe {
                    self.0
                        .as_mut_ptr()
                        .cast::<DBTableLayout<()>>()
                        .add(i)
                        .write(DBTableLayout {
                            tag: i,
                            payload: DBList::<()>(DashMap::with_capacity(0)),
                        })
                }
            }
            // The new tables were written through a raw pointer, so the length
            // has to be bumped by hand for them to become visible.
            unsafe { self.0.set_len(old_len + diff) };
        }
    }
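    /// Serializes the whole cache with `bincode` on a blocking thread,
    /// presumably because [`DBList`]'s `Serialize` impl calls `blocking_read`,
    /// which panics on an async worker.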
    async fn to_bytes(&self) -> Result<Vec<u8>> {
        let mut buf = Err(anyhow!("Couldn't serialize the db"));
        TokioScope::scope_and_block(|s| {
            s.spawn_blocking(|| buf = serialize(&self).map_err(|e| anyhow!(e)))
        });
        // serialize(&self).map_err(|e| anyhow!(e))
        buf
    }
}
impl Deref for Cache {
    type Target = Vec<DBTable>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
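/// Messages understood by the flush watchdog: `Ping` resets its idle timer
/// without flushing, `Force` requests an immediate flush.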
pub enum WatchdogMessage {
    Ping,
    Force,
}
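/// Handle to the database: the on-disk path, the in-memory [`Cache`], and the
/// channel used to talk to the background flush watchdog.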
#[derive(Debug)]
pub struct DB {
    path: PathBuf,
    pub memory_cache: Cache,
    pub watchdog_sx: Sender<WatchdogMessage>,
    pub watchdog_rx: Mutex<Receiver<WatchdogMessage>>,
}
impl DB {
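    /// Opens (or creates) the database file at `path`, takes an exclusive file
    /// lock, loads the cache into memory, leaks the handle to obtain a
    /// `&'static` reference, and spawns the flush watchdog.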
    pub async fn open(path: impl AsRef<Path>) -> Result<&'static Self> {
        let cache = if !try_exists(&path).await? {
            let file = TokioFile::create(&path).await?;
            #[cfg(not(miri))]
            file.lock_exclusive()?;
            let mut writer = BufWriter::with_capacity(file.metadata().await?.len() as usize, file);
            let new_db = Cache::new();
            writer.write_all(&new_db.to_bytes().await?).await?;
            writer.shutdown().await?;
            new_db
        } else {
            let file = TokioFile::open(&path).await?;
            #[cfg(not(miri))]
            file.lock_exclusive()?;
            let mut buf = Vec::new();
            let mut reader = BufReader::with_capacity(file.metadata().await?.len() as usize, file);
            reader.read_to_end(&mut buf).await?;
            let mut db = deserialize::<Cache>(&buf)?;
            db.migrate();
            reader.into_inner().shutdown().await?;
            db
        };
        let (sx, rx) = channel(3000);
        let db: &_ = Box::leak(Box::new(Self {
            path: path.as_ref().to_owned(),
            memory_cache: cache,
            watchdog_sx: sx,
            watchdog_rx: Mutex::new(rx),
        }));
        spawn(db.watchdog());
        Ok(db)
    }
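    /// Forwards a message to the watchdog. The `unwrap_unchecked` relies on the
    /// receiver never being dropped, which holds because it lives inside the
    /// leaked `DB`.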
    pub async fn notify_watchdog(&self, msg: WatchdogMessage) {
        unsafe { self.watchdog_sx.send(msg).await.unwrap_unchecked() }
    }
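    /// Background task: flushes the cache to disk after two minutes without a
    /// `Ping`, or immediately on `Force`; afterwards it drains any queued
    /// messages and parks until the next one arrives.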
    async fn watchdog(&self) {
        let mut receiver = self.watchdog_rx.lock().await;
        // receiver.recv().await;
        loop {
            match timeout(Duration::from_millis(2 * 60 * 1000), receiver.recv()).await {
                Ok(Some(WatchdogMessage::Ping)) => continue,
                Ok(Some(WatchdogMessage::Force)) | Err(_) if self.flush().await.is_ok() => {
                    while receiver.try_recv().is_ok() {}
                    receiver.recv().await;
                }
                _ => {
                    panic!("Database exploded");
                }
            }
        }
    }
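    /// Rewrites the database file from scratch with the current contents of
    /// the in-memory cache.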
    async fn flush(&self) -> Result<()> {
        // fixme switch to [`UringFile`] when [`tokio_uring`] becomes Send
        remove_file(&self.path).await?;
        let file = TokioFile::create(&self.path).await?;
        let mut writer = BufWriter::with_capacity(file.metadata().await?.len() as usize, file);
        writer
            .write_all(&self.memory_cache.to_bytes().await?)
            .await?;
        writer.flush().await?;
        writer.shutdown().await?;
        Ok(())
    }
}
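/// A single table: a concurrent map from record UUID to the record, with each
/// entry guarded by its own [`RwLock`].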
#[repr(transparent)]
#[derive(Default, Debug)]
pub struct DBList<T>(pub DashMap<u128, RwLock<T>>)
where
    T: Clone + Serialize + for<'a> Deserialize<'a>;
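// The serde impls are hand-written because `tokio::sync::RwLock` is not
// serializable: the lock wrapper is stripped on the way out and re-added on the
// way in. `blocking_read` panics inside an async context, which is why
// serialization goes through a blocking thread (see [`Cache::to_bytes`]).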
impl<T> Serialize for DBList<T>
where
    T: Clone + Serialize + for<'a> Deserialize<'a>,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        Serialize::serialize(
            &self
                .0
                .iter()
                .map(|x| (*x.key(), x.value().blocking_read().clone()))
                .collect::<DashMap<u128, T>>(),
            serializer,
        )
    }
}
impl<'de, T> Deserialize<'de> for DBList<T>
where
    T: Clone + Serialize + for<'a> Deserialize<'a>,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let val: DashMap<u128, T> = Deserialize::deserialize(deserializer)?;
        Ok(Self(
            val.into_iter().map(|(i, x)| (i, RwLock::new(x))).collect(),
        ))
    }
}
impl<T> DBList<T>
where
    T: Clone + Serialize + for<'a> Deserialize<'a>,
{
    pub fn insert(&self, element: T) -> u128 {
        let uuid = rand::random();
        self.0.insert(uuid, RwLock::new(element));
        uuid
    }
}
pub(crate) mod macros {
    #![allow(dead_code)]
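    /// Access to the global database stored in `crate::DATABASE`:
    /// * `db!()` yields the leaked `&'static DB`;
    /// * `db!(::Table::Alumnes)` yields the `&DBList<Alumne>` for that table;
    /// * `db!(Table::Alumnes)` yields that table's inner `DashMap`.
    ///
    /// A minimal usage sketch (assuming `crate::DATABASE` has already been set
    /// with the handle returned by [`DB::open`]):
    /// ```ignore
    /// let uuid = db!(::Table::Alumnes).insert(Alumne::default());
    /// db!().notify_watchdog(WatchdogMessage::Ping).await;
    /// ```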
    #[macro_export]
    macro_rules! db {
        () => {
            #[allow(unused_unsafe)]
            unsafe {
                *$crate::DATABASE.get().unwrap_unchecked()
            }
        };
        (Table::$x:ident) => {
            db!(::Table::$x).0
        };
        (::Table::$x:ident) => {
            #[allow(unused_unsafe)]
            unsafe {
                db!().memory_cache[DBTable::$x($crate::database::macros::zero()).tag()]
                    .get_table::<$x>()
            }
        };
    }
    /// HACK: this is used by the [`db`] macro to build an enum value with the
    /// correct variant and read its discriminant. The optimizer will probably
    /// remove the call, since the zeroed payload is never actually read.
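    /// # Safety
    /// An all-zero bit pattern must be a valid value of `T`.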
    #[inline(always)]
    pub const unsafe fn zero<T>() -> T {
        MaybeUninit::zeroed().assume_init()
    }
    use super::*;
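    // One alias per `DBTable` variant, mapping the variant's name to its
    // element type so `db!` can use the same identifier both as the variant and
    // as the `get_table` type parameter.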
    pub type Alumnes = Alumne;
    pub type Classes = Classe;
    pub type Cursos = Curs;
    pub type Professors = Professor;
    pub type ProfTokens = u128;
    pub type TutorLegals = TutorLegal;
    pub type TutorsLegalsTokens = u128;
    pub type GlobalRw = ();
    pub type GlobalRo = ();
}
pub use macros::*;
#[cfg(test)]
mod tests {
    use crate::{db, DATABASE};
    use tokio::{fs::remove_file, time::sleep};
    use super::*;
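    // End-to-end check: insert a few records, force a flush to disk, then
    // reopen the file and make sure the same number of `Alumne` rows comes back.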
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn read_write_db() {
remove_file("./temp_db").await.unwrap_or(());
DB::open("./temp_db")
.await
.and_then(|db| DATABASE.set(db).map_err(|e| anyhow!("{e:?}")))
.expect("Can't open the database");
const INSTANCES: usize = 4;
for _ in 0..INSTANCES {
db!(Table::Alumnes).insert(rand::random::<u128>(), RwLock::new(Alumne::default()));
}
db!().notify_watchdog(WatchdogMessage::Force).await;
sleep(Duration::from_millis(1000)).await;
DB::open("./temp_db")
.await
.inspect(|&db| {
assert_eq!(
unsafe {
db.memory_cache[DBTable::Alumnes(zero()).tag()]
.get_table::<Alumne>()
.0
.len()
},
INSTANCES
)
})
.expect("Can't open the database");
remove_file("./temp_db").await.unwrap_or(());
}
}