Created
November 19, 2022 16:51
-
-
Save avantgardnerio/60a53a3481f3d7844efa12346ef6f814 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use crate::utils::{err, ColDbError}; | |
use datafusion::arrow::array::{ArrayRef, Int32Array, StringArray}; | |
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; | |
use datafusion::arrow::record_batch::RecordBatch; | |
use std::collections::HashMap; | |
use std::ptr::slice_from_raw_parts_mut; | |
use std::sync::Arc; | |
use tokio::sync::Semaphore; | |
pub struct AppendableRecordBatch { | |
record_batch: RecordBatch, | |
len: usize, | |
col_bytes: Vec<usize>, | |
flag: Semaphore, | |
} | |
impl AppendableRecordBatch { | |
pub fn new(schema: Schema) -> Result<Self, ColDbError> { | |
let mut data: Vec<ArrayRef> = vec![]; | |
for f in schema.fields().iter() { | |
let col: ArrayRef = match f.data_type() { | |
DataType::Int32 => Arc::new(Int32Array::from([0i32; 1000].to_vec())), | |
// TODO: pre-allocate reasonable guess at string len | |
DataType::Utf8 => { | |
let str = " ".repeat(100); | |
let vec = [str.as_str(); 1000].to_vec(); | |
Arc::new(StringArray::from(vec)) | |
} | |
_ => Err(err( | |
format!("Unsupported type: {:?}", f.data_type()).as_str() | |
))?, | |
}; | |
data.push(col); | |
} | |
let col_bytes = (0..schema.fields.len()).map(|_| 0).collect(); | |
let rb = RecordBatch::try_new(Arc::new(schema), data)?; | |
Ok(Self { | |
record_batch: rb, | |
len: 0, | |
col_bytes, | |
flag: Semaphore::new(1), | |
}) | |
} | |
pub fn schema(&self) -> SchemaRef { | |
self.record_batch.schema() | |
} | |
pub fn append_row(&mut self, map: HashMap<String, String>) -> Result<(), ColDbError> { | |
// TODO: bounds checks everywhere | |
let _lock = self.flag.acquire(); | |
for (idx, field) in self.schema().fields.iter().enumerate() { | |
let val = map.get(field.name()).ok_or(err("Value not found!"))?; | |
let ar = self.record_batch.column(idx); | |
let buf = ar | |
.data() | |
.buffers() | |
.get(0) | |
.ok_or(err("Invalid expression!"))?; | |
match field.data_type() { | |
DataType::Int32 => { | |
let buf = buf.as_ptr() as *mut i32; // casting const to mut | |
let i = val.parse().map_err(|_| err("Can't parse value!"))?; | |
// rust badness everyone will hate | |
unsafe { | |
*buf.add(self.len) = i; | |
} | |
} | |
DataType::Utf8 => { | |
let offsets = buf.as_ptr() as *mut i32; | |
let val_buf = ar | |
.data() | |
.buffers() | |
.get(1) | |
.ok_or(err("Invalid expression!"))?; | |
let values = val_buf.as_ptr() as *mut u8; | |
let len = self.col_bytes.get_mut(idx).ok_or(err("Invalid col idx"))?; | |
unsafe { | |
let array = slice_from_raw_parts_mut(values.add(*len), val.len()) | |
.as_mut() | |
.ok_or(err("Invalid cast"))?; | |
array.copy_from_slice(val.as_bytes()); | |
*len += val.len(); | |
*offsets.add(self.len + 1) = *len as i32; | |
} | |
} | |
_ => Err(err("Can't set type!"))?, | |
} | |
} | |
self.len += 1; | |
Ok(()) | |
} | |
pub fn len(&self) -> usize { | |
self.len | |
} | |
pub fn as_slice(&self) -> RecordBatch { | |
self.record_batch.slice(0, self.len()) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment