Created
November 19, 2022 19:51
-
-
Save avantgardnerio/48d977ea6bd28c790cfb6df09250336d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use crate::utils::{err, ColDbError}; | |
use datafusion::arrow::array::{ArrayRef, Int32Array, StringArray}; | |
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; | |
use datafusion::arrow::record_batch::RecordBatch; | |
use std::ptr::slice_from_raw_parts_mut; | |
use std::sync::Arc; | |
use tokio::sync::Semaphore; | |
pub struct AppendableRecordBatch { | |
record_batch: RecordBatch, | |
len: usize, | |
col_bytes: Vec<usize>, | |
flag: Semaphore, | |
} | |
impl AppendableRecordBatch { | |
pub fn new(schema: Schema) -> Result<Self, ColDbError> { | |
let mut data: Vec<ArrayRef> = vec![]; | |
for f in schema.fields().iter() { | |
let col: ArrayRef = match f.data_type() { | |
DataType::Int32 => Arc::new(Int32Array::from([0i32; 1000].to_vec())), | |
// TODO: pre-allocate reasonable guess at string len | |
DataType::Utf8 => { | |
let str = " ".repeat(100); | |
let vec = [str.as_str(); 1000].to_vec(); | |
Arc::new(StringArray::from(vec)) | |
} | |
_ => Err(err( | |
format!("Unsupported type: {:?}", f.data_type()).as_str() | |
))?, | |
}; | |
data.push(col); | |
} | |
let col_bytes = (0..schema.fields.len()).map(|_| 0).collect(); | |
let rb = RecordBatch::try_new(Arc::new(schema), data)?; | |
Ok(Self { | |
record_batch: rb, | |
len: 0, | |
col_bytes, | |
flag: Semaphore::new(1), | |
}) | |
} | |
pub fn schema(&self) -> SchemaRef { | |
self.record_batch.schema() | |
} | |
pub fn extend(&mut self, source: RecordBatch) -> Result<(), ColDbError> { | |
// TODO: bounds checks everywhere, schema match check | |
let _lock = self.flag.acquire(); | |
for row_idx in 0..source.num_rows() { | |
for (idx, field) in self.schema().fields.iter().enumerate() { | |
let ar = self.record_batch.column(idx); | |
let buf = ar | |
.data() | |
.buffers() | |
.get(0) | |
.ok_or(err("Invalid expression!"))?; | |
match field.data_type() { | |
DataType::Int32 => { | |
let src_col = source | |
.column(idx) | |
.as_any() | |
.downcast_ref::<Int32Array>() | |
.ok_or(err("Types don't match"))?; | |
let val = src_col.value(row_idx); | |
let buf = buf.as_ptr() as *mut i32; // casting const to mut | |
// rust badness everyone will hate | |
unsafe { | |
*buf.add(self.len) = val; | |
} | |
} | |
DataType::Utf8 => { | |
let src_col = source | |
.column(idx) | |
.as_any() | |
.downcast_ref::<StringArray>() | |
.ok_or(err("Types don't match"))?; | |
let val = src_col.value(row_idx); | |
let offsets = buf.as_ptr() as *mut i32; | |
let val_buf = ar | |
.data() | |
.buffers() | |
.get(1) | |
.ok_or(err("Invalid expression!"))?; | |
let values = val_buf.as_ptr() as *mut u8; | |
let len = self.col_bytes.get_mut(idx).ok_or(err("Invalid col idx"))?; | |
unsafe { | |
let array = slice_from_raw_parts_mut(values.add(*len), val.len()) | |
.as_mut() | |
.ok_or(err("Invalid cast"))?; | |
array.copy_from_slice(val.as_bytes()); | |
*len += val.len(); | |
*offsets.add(self.len + 1) = *len as i32; | |
} | |
} | |
_ => Err(err("Can't set type!"))?, | |
} | |
} | |
self.len += 1; | |
} | |
Ok(()) | |
} | |
pub fn len(&self) -> usize { | |
self.len | |
} | |
pub fn as_slice(&self) -> RecordBatch { | |
self.record_batch.slice(0, self.len()) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment