Skip to content

Instantly share code, notes, and snippets.

@avantgardnerio
Created November 19, 2022 16:51
Show Gist options
  • Save avantgardnerio/60a53a3481f3d7844efa12346ef6f814 to your computer and use it in GitHub Desktop.
Save avantgardnerio/60a53a3481f3d7844efa12346ef6f814 to your computer and use it in GitHub Desktop.
use crate::utils::{err, ColDbError};
use datafusion::arrow::array::{ArrayRef, Int32Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use std::collections::HashMap;
use std::ptr::slice_from_raw_parts_mut;
use std::sync::Arc;
use tokio::sync::Semaphore;
/// A `RecordBatch` with pre-allocated columns that rows are appended into in place.
///
/// Only the first `len` rows are meaningful; the rest of each column is
/// placeholder data written when the batch was created.
pub struct AppendableRecordBatch {
    /// Backing batch whose column buffers are overwritten as rows are appended.
    record_batch: RecordBatch,
    /// Number of rows appended so far.
    len: usize,
    /// Per-column count of string bytes written so far (indexed by column position).
    col_bytes: Vec<usize>,
    /// Single-permit semaphore created alongside the batch.
    flag: Semaphore,
}
impl AppendableRecordBatch {
    /// Number of rows pre-allocated for every column by [`Self::new`].
    const ROW_CAPACITY: usize = 1000;
    /// Bytes of string storage pre-allocated per row for each `Utf8` column.
    const STR_BYTES_PER_ROW: usize = 100;

    /// Creates an empty batch with pre-allocated storage for every column in `schema`.
    ///
    /// `Int32` columns are filled with zeros and `Utf8` columns with strings of
    /// spaces; these placeholders are overwritten by [`Self::append_row`].
    ///
    /// # Errors
    ///
    /// Returns an error if the schema contains a type other than `Int32` or
    /// `Utf8`, or if `RecordBatch::try_new` rejects the columns.
    pub fn new(schema: Schema) -> Result<Self, ColDbError> {
        let mut data: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
        for f in schema.fields().iter() {
            let col: ArrayRef = match f.data_type() {
                DataType::Int32 => Arc::new(Int32Array::from(vec![0i32; Self::ROW_CAPACITY])),
                // TODO: pre-allocate reasonable guess at string len
                DataType::Utf8 => {
                    let filler = " ".repeat(Self::STR_BYTES_PER_ROW);
                    Arc::new(StringArray::from(vec![filler.as_str(); Self::ROW_CAPACITY]))
                }
                other => Err(err(format!("Unsupported type: {:?}", other).as_str()))?,
            };
            data.push(col);
        }
        let col_bytes = vec![0; schema.fields().len()];
        let rb = RecordBatch::try_new(Arc::new(schema), data)?;
        Ok(Self {
            record_batch: rb,
            len: 0,
            col_bytes,
            flag: Semaphore::new(1),
        })
    }

    /// Returns the schema of the backing record batch.
    pub fn schema(&self) -> SchemaRef {
        self.record_batch.schema()
    }

    /// Appends one row, reading each column's value from `map` by field name.
    ///
    /// The row is validated in full before anything is written, so a failed
    /// call leaves the batch exactly as it was.
    ///
    /// # Errors
    ///
    /// Returns an error if the semaphore permit is unavailable, a field has no
    /// entry in `map`, an `Int32` value does not parse, a column has an
    /// unsupported type, or the pre-allocated buffers have no room left.
    pub fn append_row(&mut self, map: HashMap<String, String>) -> Result<(), ColDbError> {
        /// A validated cell, ready to be written without further checks.
        enum Staged<'a> {
            Int32 {
                slots: *mut i32,
                value: i32,
            },
            Utf8 {
                col: usize,
                offsets: *mut i32,
                values: *mut u8,
                bytes: &'a [u8],
                start: usize,
                end: usize,
                end_offset: i32,
            },
        }

        // `Semaphore::acquire` is async; without `.await` it never takes a permit.
        // `try_acquire` takes it synchronously and releases it when `_permit` drops.
        let _permit = self
            .flag
            .try_acquire()
            .map_err(|_| err("Batch is locked!"))?;

        let schema = self.schema();
        let row = self.len;
        let width = std::mem::size_of::<i32>();

        // Pass 1: look up, parse and bounds-check every column. Nothing is written
        // yet, so any `?` below leaves the batch untouched.
        let mut staged = Vec::with_capacity(schema.fields().len());
        for (idx, field) in schema.fields().iter().enumerate() {
            let val = map
                .get(field.name())
                .ok_or_else(|| err("Value not found!"))?;
            let buffers = self.record_batch.column(idx).data().buffers();
            let buf = buffers.first().ok_or_else(|| err("Invalid expression!"))?;
            match field.data_type() {
                DataType::Int32 => {
                    // Slot `row` must lie inside the values buffer.
                    let needed = row.checked_add(1).and_then(|r| r.checked_mul(width));
                    Self::ensure_fits(needed, buf.len(), "Batch is full!")?;
                    let value: i32 = val.parse().map_err(|_| err("Can't parse value!"))?;
                    staged.push(Staged::Int32 {
                        slots: buf.as_ptr() as *mut i32, // casting const to mut
                        value,
                    });
                }
                DataType::Utf8 => {
                    // Offset slot `row + 1` must lie inside the offsets buffer.
                    let needed = row.checked_add(2).and_then(|r| r.checked_mul(width));
                    Self::ensure_fits(needed, buf.len(), "Batch is full!")?;
                    let val_buf = buffers.get(1).ok_or_else(|| err("Invalid expression!"))?;
                    let start = *self
                        .col_bytes
                        .get(idx)
                        .ok_or_else(|| err("Invalid col idx"))?;
                    // The string bytes must fit in what is left of the values buffer.
                    let end = start.checked_add(val.len());
                    Self::ensure_fits(end, val_buf.len(), "String buffer is full!")?;
                    let end = start + val.len();
                    let end_offset =
                        i32::try_from(end).map_err(|_| err("String buffer is full!"))?;
                    staged.push(Staged::Utf8 {
                        col: idx,
                        offsets: buf.as_ptr() as *mut i32,
                        values: val_buf.as_ptr() as *mut u8,
                        bytes: val.as_bytes(),
                        start,
                        end,
                        end_offset,
                    });
                }
                _ => Err(err("Can't set type!"))?,
            }
        }

        // Pass 2: infallible writes.
        // NOTE(review): these writes go through pointers obtained from shared,
        // nominally immutable Arrow buffers, as in the original design. Batches
        // handed out earlier by `as_slice` presumably alias the same memory —
        // confirm that no reader runs concurrently with an append.
        for cell in staged {
            match cell {
                Staged::Int32 { slots, value } => {
                    // SAFETY: pass 1 checked that `(row + 1) * size_of::<i32>()` does
                    // not exceed this buffer's length, so slot `row` is in bounds.
                    unsafe {
                        *slots.add(row) = value;
                    }
                }
                Staged::Utf8 {
                    col,
                    offsets,
                    values,
                    bytes,
                    start,
                    end,
                    end_offset,
                } => {
                    // SAFETY: pass 1 checked that `start + bytes.len()` fits in the
                    // values buffer and that offset slot `row + 1` fits in the
                    // offsets buffer. `bytes` borrows from `map`, which does not
                    // overlap the column storage.
                    unsafe {
                        (&mut *slice_from_raw_parts_mut(values.add(start), bytes.len()))
                            .copy_from_slice(bytes);
                        *offsets.add(row + 1) = end_offset;
                    }
                    // `col` was validated against `col_bytes` in pass 1.
                    self.col_bytes[col] = end;
                }
            }
        }
        self.len += 1;
        Ok(())
    }

    /// Succeeds when `needed` is `Some(n)` with `n <= available`.
    ///
    /// `None` stands for an arithmetic overflow while computing the requirement
    /// and is reported with the same `msg`.
    fn ensure_fits(needed: Option<usize>, available: usize, msg: &str) -> Result<(), ColDbError> {
        needed
            .filter(|&bytes| bytes <= available)
            .ok_or_else(|| err(msg))?;
        Ok(())
    }

    /// Returns the number of rows appended so far.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when no rows have been appended.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns a `RecordBatch` covering only the first `len()` rows.
    pub fn as_slice(&self) -> RecordBatch {
        self.record_batch.slice(0, self.len())
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment