Skip to content

Instantly share code, notes, and snippets.

@avantgardnerio
Created November 19, 2022 19:51
Show Gist options
  • Save avantgardnerio/48d977ea6bd28c790cfb6df09250336d to your computer and use it in GitHub Desktop.
Save avantgardnerio/48d977ea6bd28c790cfb6df09250336d to your computer and use it in GitHub Desktop.
use crate::utils::{err, ColDbError};
use datafusion::arrow::array::{ArrayRef, Int32Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use std::ptr::slice_from_raw_parts_mut;
use std::sync::Arc;
use tokio::sync::Semaphore;
/// A record batch whose column buffers are allocated up front and then filled
/// in place, row by row, by `extend`.
///
/// Only the first `len` rows hold appended data; `as_slice` exposes exactly
/// that prefix.
pub struct AppendableRecordBatch {
// Backing batch built by `new`; its columns are pre-filled with placeholder
// values that `extend` overwrites in place.
record_batch: RecordBatch,
// Number of rows appended so far (index of the next row to be written).
len: usize,
// One entry per column: for `Utf8` columns, the number of bytes already
// written into the values buffer. Stays 0 for other column types.
col_bytes: Vec<usize>,
// Semaphore created with a single permit, intended to guard `extend`.
flag: Semaphore,
}
impl AppendableRecordBatch {
    /// Number of rows pre-allocated for every column.
    const ROW_CAPACITY: usize = 1000;
    /// Bytes pre-allocated per row in the values buffer of a `Utf8` column.
    // TODO: pre-allocate reasonable guess at string len
    const UTF8_BYTES_PER_ROW: usize = 100;

    /// Creates an empty appendable batch for `schema`.
    ///
    /// Every column is pre-allocated with `ROW_CAPACITY` placeholder rows
    /// (zeros for `Int32`, `UTF8_BYTES_PER_ROW` spaces for `Utf8`) that
    /// `extend` later overwrites in place.
    ///
    /// # Errors
    ///
    /// Returns an error if the schema contains a type other than `Int32` or
    /// `Utf8`, or if the backing `RecordBatch` cannot be constructed.
    pub fn new(schema: Schema) -> Result<Self, ColDbError> {
        let data = schema
            .fields()
            .iter()
            .map(|f| -> Result<ArrayRef, ColDbError> {
                let col: ArrayRef = match f.data_type() {
                    DataType::Int32 => {
                        Arc::new(Int32Array::from(vec![0i32; Self::ROW_CAPACITY]))
                    }
                    DataType::Utf8 => {
                        let filler = " ".repeat(Self::UTF8_BYTES_PER_ROW);
                        Arc::new(StringArray::from(vec![
                            filler.as_str();
                            Self::ROW_CAPACITY
                        ]))
                    }
                    other => Err(err(format!("Unsupported type: {:?}", other).as_str()))?,
                };
                Ok(col)
            })
            .collect::<Result<Vec<ArrayRef>, ColDbError>>()?;

        // One byte cursor per column; only `Utf8` columns ever advance theirs.
        let col_bytes = vec![0; data.len()];
        let record_batch = RecordBatch::try_new(Arc::new(schema), data)?;
        Ok(Self {
            record_batch,
            len: 0,
            col_bytes,
            flag: Semaphore::new(1),
        })
    }

    /// Returns the schema of the backing batch.
    pub fn schema(&self) -> SchemaRef {
        self.record_batch.schema()
    }

    /// Appends every row of `source` to this batch, writing into the
    /// pre-allocated buffers in place.
    ///
    /// The call is all-or-nothing: types and capacity are validated for every
    /// column before the first byte is written, so a failure leaves the batch
    /// unchanged.
    ///
    /// Columns are matched by position. Extra trailing columns in `source` are
    /// ignored. Null slots in `source` are not propagated: whatever
    /// `value(row)` yields for them is stored as a regular value.
    ///
    /// # Errors
    ///
    /// Returns an error if
    /// * the batch is already locked by another writer,
    /// * `source` has fewer columns than this batch's schema,
    /// * a source column's array type does not match the schema's type,
    /// * the schema contains a type other than `Int32` / `Utf8`, or
    /// * the rows or string bytes do not fit in the remaining capacity.
    pub fn extend(&mut self, source: RecordBatch) -> Result<(), ColDbError> {
        /// A fully validated write for one column.
        enum Plan<'a> {
            Int32 {
                src: &'a Int32Array,
                dst: *mut i32,
            },
            Utf8 {
                src: &'a StringArray,
                offsets: *mut i32,
                values: *mut u8,
                cursor: usize,
            },
        }

        // `Semaphore::acquire` only builds a future; without `.await` it never
        // takes the permit. `try_acquire` takes it synchronously.
        let _permit = self
            .flag
            .try_acquire()
            .map_err(|_| err("Batch is locked by another writer"))?;

        let new_rows = source.num_rows();
        if new_rows == 0 {
            return Ok(());
        }

        let schema = self.schema();
        let fields = schema.fields();
        if source.num_columns() < fields.len() {
            Err(err("Column count doesn't match"))?;
        }

        let row_capacity = self.record_batch.num_rows();
        if new_rows > row_capacity.saturating_sub(self.len) {
            Err(err("Not enough row capacity"))?;
        }
        let end_row = self.len + new_rows;
        let offset_width = std::mem::size_of::<i32>();

        // Pass 1: validate every column and record where to write. Nothing is
        // mutated yet, so any error below leaves the batch untouched.
        let mut plans: Vec<Plan<'_>> = Vec::with_capacity(fields.len());
        for (idx, field) in fields.iter().enumerate() {
            let ar = self.record_batch.column(idx);
            let buffers = ar.data().buffers();
            let buf = buffers
                .first()
                .ok_or_else(|| err("Invalid expression!"))?;
            match field.data_type() {
                DataType::Int32 => {
                    let src = source
                        .column(idx)
                        .as_any()
                        .downcast_ref::<Int32Array>()
                        .ok_or_else(|| err("Types don't match"))?;
                    if buf.len() / offset_width < end_row {
                        Err(err("Not enough row capacity"))?;
                    }
                    plans.push(Plan::Int32 {
                        src,
                        dst: buf.as_ptr() as *mut i32, // casting const to mut
                    });
                }
                DataType::Utf8 => {
                    let src = source
                        .column(idx)
                        .as_any()
                        .downcast_ref::<StringArray>()
                        .ok_or_else(|| err("Types don't match"))?;
                    let val_buf = buffers
                        .get(1)
                        .ok_or_else(|| err("Invalid expression!"))?;
                    let cursor = *self
                        .col_bytes
                        .get(idx)
                        .ok_or_else(|| err("Invalid col idx"))?;
                    // The offsets buffer needs one more slot than there are rows.
                    if buf.len() / offset_width < end_row + 1 {
                        Err(err("Not enough row capacity"))?;
                    }
                    let needed: usize = (0..new_rows).map(|row| src.value(row).len()).sum();
                    let end = cursor.saturating_add(needed);
                    // Offsets are stored as i32, so the cursor must fit in one.
                    if end > val_buf.len() || end > i32::MAX as usize {
                        Err(err("Not enough string capacity"))?;
                    }
                    plans.push(Plan::Utf8 {
                        src,
                        offsets: buf.as_ptr() as *mut i32, // casting const to mut
                        values: val_buf.as_ptr() as *mut u8, // casting const to mut
                        cursor,
                    });
                }
                _ => Err(err("Can't set type!"))?,
            }
        }

        // Pass 2: perform the writes.
        //
        // NOTE(review): the buffers are written through pointers obtained from
        // `Buffer::as_ptr` (a `*const u8`). This relies on nobody reading the
        // rows at or beyond `self.len` while they are being filled; a
        // builder-based design would avoid the const-to-mut cast altogether.
        for (idx, plan) in plans.into_iter().enumerate() {
            match plan {
                Plan::Int32 { src, dst } => {
                    for row in 0..new_rows {
                        // SAFETY: pass 1 checked that the buffer holds at least
                        // `end_row` i32 slots and `self.len + row < end_row`.
                        unsafe {
                            *dst.add(self.len + row) = src.value(row);
                        }
                    }
                }
                Plan::Utf8 {
                    src,
                    offsets,
                    values,
                    mut cursor,
                } => {
                    for row in 0..new_rows {
                        let val = src.value(row);
                        // SAFETY: pass 1 checked that all bytes of this column
                        // fit in the values buffer starting at `cursor`, that
                        // the offsets buffer holds `end_row + 1` i32 slots, and
                        // that the final cursor fits in an i32. `val` lives in
                        // `source`, so the two regions cannot overlap.
                        unsafe {
                            let dst =
                                &mut *slice_from_raw_parts_mut(values.add(cursor), val.len());
                            dst.copy_from_slice(val.as_bytes());
                            cursor += val.len();
                            *offsets.add(self.len + row + 1) = cursor as i32;
                        }
                    }
                    // `idx` was validated against `col_bytes` in pass 1.
                    self.col_bytes[idx] = cursor;
                }
            }
        }

        self.len = end_row;
        Ok(())
    }

    /// Returns the number of rows appended so far.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if no rows have been appended yet.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns a `RecordBatch` containing only the rows appended so far.
    pub fn as_slice(&self) -> RecordBatch {
        self.record_batch.slice(0, self.len())
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment