Created
May 5, 2024 02:32
-
-
Save mooreniemi/0bb2c34cbe991456d55e0f85260f4a37 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package] | |
name = "poodah" | |
version = "0.1.0" | |
edition = "2021" | |
[dependencies] | |
arrow = "25.0.0" | |
parquet = "25.0.0" | |
rusoto_core = "0.48.0" | |
rusoto_mock = "0.48.0" | |
rusoto_s3 = "0.48.0" | |
tokio = { version = "1.21.2", features = ["full"] } | |
mockall = "*" | |
bytes = "*" | |
url = "*" | |
rayon = "*" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use arrow::array::{Array, ArrayRef, Float64Array, StringArray}; | |
use arrow::datatypes::{DataType, Field, Schema}; | |
use arrow::record_batch::RecordBatch; | |
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; | |
use parquet::arrow::ArrowWriter; | |
use parquet::file::reader::{FileReader, SerializedFileReader}; | |
use rayon::prelude::*; | |
use std::cmp::Reverse; | |
use std::cmp::{self, Ordering}; | |
use std::collections::BinaryHeap; | |
use std::env; | |
use std::fs::File; | |
use std::path::Path; | |
use std::sync::{Arc, Mutex}; | |
use std::time::Instant; | |
use url::Url; | |
#[derive(PartialEq)] | |
struct MinNonNan(f64); | |
impl Eq for MinNonNan {} | |
impl PartialOrd for MinNonNan { | |
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { | |
other.0.partial_cmp(&self.0) | |
} | |
} | |
impl Ord for MinNonNan { | |
fn cmp(&self, other: &MinNonNan) -> Ordering { | |
self.partial_cmp(other).unwrap() | |
} | |
} | |
fn extract_domain(uri: &str) -> String { | |
let parsed_url = Url::parse(uri).unwrap(); | |
parsed_url.domain().unwrap().to_string() | |
} | |
fn total_rows_in_parquet(file_path: &str) -> Result<i64, Box<dyn std::error::Error>> { | |
let file = File::open(Path::new(file_path))?; | |
let reader = SerializedFileReader::new(file)?; | |
let metadata = reader.metadata(); | |
let num_rows = metadata.file_metadata().num_rows(); | |
Ok(num_rows) | |
} | |
fn main() -> Result<(), Box<dyn std::error::Error>> { | |
// get command line arguments | |
let args: Vec<String> = env::args().collect(); | |
let shard_num: usize = args.get(1).unwrap_or(&"0".to_string()).parse().unwrap_or(0); | |
let output_rows: usize = args | |
.get(2) | |
.unwrap_or(&"5000".to_string()) | |
.parse() | |
.unwrap_or(5000); | |
let batch_size: usize = args | |
.get(3) | |
.unwrap_or(&"1024".to_string()) | |
.parse() | |
.unwrap_or(1024); | |
// file paths | |
let input_path = format!("/tmp/shards/shard_num={}/data.parquet", shard_num); | |
let output_path = format!("/tmp/output/shard_num={}/data.parquet", shard_num); | |
let file = File::open(input_path.clone())?; | |
// just read the metadata here | |
let total_rows = total_rows_in_parquet(&input_path).expect("got total rows"); | |
println!( | |
"shard_num: {}, output_rows: {}, total_rows: {}", | |
shard_num, output_rows, total_rows | |
); | |
// we want to round up or our total output rows will be a bit high | |
let num_batches = total_rows | |
.checked_div(batch_size as i64) | |
.map(|div_result| { | |
div_result | |
+ if total_rows % batch_size as i64 != 0 { | |
1 | |
} else { | |
0 | |
} | |
}) | |
.unwrap_or(0); | |
// handle the case where we have fewer output rows than batches | |
let per_batch_max = cmp::max(output_rows / num_batches as usize, 1); | |
println!( | |
"num_batches: {}, per_batch_max: {}", | |
num_batches, per_batch_max | |
); | |
let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); | |
println!("Converted arrow schema is: {}", builder.schema()); | |
// default is 1024 | |
let mut arrow_reader = builder.with_batch_size(batch_size).build().unwrap(); | |
// input_schema read from the file itself | |
let output_schema = Arc::new(Schema::new(vec![ | |
Field::new("qps", DataType::Float64, false), | |
Field::new("uid", DataType::Utf8, false), | |
Field::new("uri", DataType::Utf8, false), | |
Field::new("domain", DataType::Utf8, false), | |
])); | |
let all_batches_start = Instant::now(); | |
let new_file = File::create(output_path)?; | |
// example of metadata, must be defined before writing begins | |
let metadata = parquet::file::metadata::KeyValue::new("key".to_string(), "value".to_string()); | |
let props = parquet::file::properties::WriterProperties::builder() | |
.set_key_value_metadata(Some(vec![metadata])) | |
.build(); | |
let mut writer = ArrowWriter::try_new(new_file, output_schema.clone(), Some(props))?; | |
let mut total = 0; | |
let mut total_batches = 0; | |
while let Some(maybe_batch) = arrow_reader.next() { | |
total_batches += 1; | |
let mut min_heap = BinaryHeap::new(); | |
let start = Instant::now(); | |
let record_batch = maybe_batch?; | |
//dbg!(&record_batch); | |
let qps_score_column = record_batch | |
.column(0) | |
.as_any() | |
.downcast_ref::<Float64Array>() | |
.expect("qps score in first column"); | |
for i in 0..record_batch.num_rows() { | |
let qps_score = qps_score_column.value(i); | |
let current_score = Reverse(MinNonNan(qps_score)); | |
if min_heap.len() < per_batch_max { | |
min_heap.push(current_score); | |
} else if current_score < *min_heap.peek().unwrap() { | |
min_heap.pop(); | |
min_heap.push(current_score); | |
} | |
} | |
let min_score = min_heap.peek().unwrap().0 .0; | |
//dbg!("batch min_score: {}", min_score); | |
let uid_column = record_batch | |
.column(1) | |
.as_any() | |
.downcast_ref::<StringArray>() | |
.unwrap(); | |
let uri_column = record_batch | |
.column(2) | |
.as_any() | |
.downcast_ref::<StringArray>() | |
.unwrap(); | |
let mut qps_array = Vec::with_capacity(output_rows); | |
let mut uid_array = Vec::with_capacity(output_rows); | |
let mut uri_array = Vec::with_capacity(output_rows); | |
let mut domain_array = Vec::with_capacity(output_rows); | |
for i in 0..record_batch.num_rows() { | |
let qps_score = qps_score_column.value(i); | |
if qps_score <= min_score { | |
total += 1; | |
qps_array.push(qps_score); | |
let uid = uid_column.value(i); | |
uid_array.push(uid); | |
let uri = uri_column.value(i); | |
uri_array.push(uri); | |
domain_array.push(extract_domain(uri)); | |
} | |
} | |
let qps_array: ArrayRef = Arc::new(Float64Array::from(qps_array)); | |
let uid_array: ArrayRef = Arc::new(StringArray::from(uid_array)); | |
let uri_array: ArrayRef = Arc::new(StringArray::from(uri_array)); | |
let domain_array: ArrayRef = Arc::new(StringArray::from(domain_array)); | |
let new_batch = RecordBatch::try_new( | |
output_schema.clone(), | |
vec![qps_array, uid_array, uri_array, domain_array], | |
)?; | |
writer.write(&new_batch)?; | |
let duration = start.elapsed(); | |
//println!("batch took: {}s", duration.as_secs()); | |
} | |
writer.close()?; | |
let duration = all_batches_start.elapsed(); | |
println!( | |
"from {} rows all {} batches took: {}s, output rows count: {}", | |
total_rows, | |
total_batches, | |
duration.as_secs(), | |
total | |
); | |
Ok(()) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
num_rows = 1_000_000 | |
uids = [f"uid_{i}" for i in range(num_rows)] | |
uris = [f"https://example.com/path/{i}" for i in range(num_rows)] | |
qps = np.random.weibull(2, size=num_rows) | |
data = {"qps": qps, "uid": uids, "uri": uris} | |
df = pd.DataFrame(data) | |
df.to_parquet("/tmp/shards/shard_num=0/data.parquet", engine="pyarrow") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment