# Cargo.toml
[package]
name = "poodah"
version = "0.1.0"
edition = "2021"

[dependencies]
arrow = "25.0.0"
parquet = "25.0.0"
rusoto_core = "0.48.0"
rusoto_mock = "0.48.0"
rusoto_s3 = "0.48.0"
tokio = { version = "1.21.2", features = ["full"] }
mockall = "*"
bytes = "*"
url = "*"
rayon = "*"

// Rust binary: reads a parquet shard, uses a bounded heap to pick a per-batch
// qps threshold, and writes the rows at or below that threshold (plus a
// derived domain column) to a new parquet file.
use arrow::array::{Array, ArrayRef, Float64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::arrow::ArrowWriter;
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::cmp::{self, Ordering, Reverse};
use std::collections::BinaryHeap;
use std::env;
use std::fs::File;
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use url::Url;
/// f64 wrapper that is Ord (it panics on NaN) with a reversed comparison:
/// a larger qps value compares as "less".
#[derive(PartialEq)]
struct MinNonNan(f64);

impl Eq for MinNonNan {}

impl PartialOrd for MinNonNan {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Note the swapped operands: the ordering is reversed relative to f64.
        other.0.partial_cmp(&self.0)
    }
}

impl Ord for MinNonNan {
    fn cmp(&self, other: &MinNonNan) -> Ordering {
        self.partial_cmp(other).unwrap()
    }
}
/// Extract the domain from a URI; panics if the URI does not parse or has no domain.
fn extract_domain(uri: &str) -> String {
    let parsed_url = Url::parse(uri).unwrap();
    parsed_url.domain().unwrap().to_string()
}
fn total_rows_in_parquet(file_path: &str) -> Result<i64, Box<dyn std::error::Error>> {
    let file = File::open(Path::new(file_path))?;
    let reader = SerializedFileReader::new(file)?;
    let metadata = reader.metadata();
    let num_rows = metadata.file_metadata().num_rows();
    Ok(num_rows)
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // get command line arguments: shard number, output rows, batch size
    let args: Vec<String> = env::args().collect();
    let shard_num: usize = args.get(1).unwrap_or(&"0".to_string()).parse().unwrap_or(0);
    let output_rows: usize = args
        .get(2)
        .unwrap_or(&"5000".to_string())
        .parse()
        .unwrap_or(5000);
    let batch_size: usize = args
        .get(3)
        .unwrap_or(&"1024".to_string())
        .parse()
        .unwrap_or(1024);

    // file paths
    let input_path = format!("/tmp/shards/shard_num={}/data.parquet", shard_num);
    let output_path = format!("/tmp/output/shard_num={}/data.parquet", shard_num);
    let file = File::open(&input_path)?;

    // just read the metadata here
    let total_rows = total_rows_in_parquet(&input_path).expect("got total rows");
    println!(
        "shard_num: {}, output_rows: {}, total_rows: {}",
        shard_num, output_rows, total_rows
    );

    // round up the batch count, or the total output row count will come out a bit high
    let num_batches = total_rows
        .checked_div(batch_size as i64)
        .map(|div_result| {
            div_result
                + if total_rows % batch_size as i64 != 0 {
                    1
                } else {
                    0
                }
        })
        .unwrap_or(0);

    // handle the case where we have fewer output rows than batches
    let per_batch_max = cmp::max(output_rows / num_batches as usize, 1);
    println!(
        "num_batches: {}, per_batch_max: {}",
        num_batches, per_batch_max
    );
    let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap();
    println!("Converted arrow schema is: {}", builder.schema());
    // default batch size is 1024
    let mut arrow_reader = builder.with_batch_size(batch_size).build().unwrap();

    // the input schema is read from the file itself; this is the output schema
    let output_schema = Arc::new(Schema::new(vec![
        Field::new("qps", DataType::Float64, false),
        Field::new("uid", DataType::Utf8, false),
        Field::new("uri", DataType::Utf8, false),
        Field::new("domain", DataType::Utf8, false),
    ]));

    let all_batches_start = Instant::now();
    // make sure the output partition directory exists before creating the file
    if let Some(parent) = Path::new(&output_path).parent() {
        std::fs::create_dir_all(parent)?;
    }
    let new_file = File::create(output_path)?;

    // example of key/value metadata; it must be defined before writing begins
    let metadata = parquet::file::metadata::KeyValue::new("key".to_string(), "value".to_string());
    let props = parquet::file::properties::WriterProperties::builder()
        .set_key_value_metadata(Some(vec![metadata]))
        .build();
    let mut writer = ArrowWriter::try_new(new_file, output_schema.clone(), Some(props))?;

    let mut total = 0;
    let mut total_batches = 0;
    while let Some(maybe_batch) = arrow_reader.next() {
        total_batches += 1;
        let mut min_heap = BinaryHeap::new();
        let start = Instant::now();
        let record_batch = maybe_batch?;
        //dbg!(&record_batch);
        let qps_score_column = record_batch
            .column(0)
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("qps score in first column");

        // Bounded heap of size per_batch_max. Reverse(MinNonNan(x)) compares in
        // natural f64 order, so the max-heap retains the per_batch_max smallest
        // scores and peek() is the largest of those.
        for i in 0..record_batch.num_rows() {
            let qps_score = qps_score_column.value(i);
            let current_score = Reverse(MinNonNan(qps_score));
            if min_heap.len() < per_batch_max {
                min_heap.push(current_score);
            } else if current_score < *min_heap.peek().unwrap() {
                min_heap.pop();
                min_heap.push(current_score);
            }
        }
        // the per-batch qps threshold
        let min_score = min_heap.peek().unwrap().0 .0;
        //dbg!("batch min_score: {}", min_score);

        let uid_column = record_batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let uri_column = record_batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        let mut qps_array = Vec::with_capacity(output_rows);
        let mut uid_array = Vec::with_capacity(output_rows);
        let mut uri_array = Vec::with_capacity(output_rows);
        let mut domain_array = Vec::with_capacity(output_rows);
        // keep every row at or below the threshold, adding a derived domain column
        for i in 0..record_batch.num_rows() {
            let qps_score = qps_score_column.value(i);
            if qps_score <= min_score {
                total += 1;
                qps_array.push(qps_score);
                let uid = uid_column.value(i);
                uid_array.push(uid);
                let uri = uri_column.value(i);
                uri_array.push(uri);
                domain_array.push(extract_domain(uri));
            }
        }

        let qps_array: ArrayRef = Arc::new(Float64Array::from(qps_array));
        let uid_array: ArrayRef = Arc::new(StringArray::from(uid_array));
        let uri_array: ArrayRef = Arc::new(StringArray::from(uri_array));
        let domain_array: ArrayRef = Arc::new(StringArray::from(domain_array));
        let new_batch = RecordBatch::try_new(
            output_schema.clone(),
            vec![qps_array, uid_array, uri_array, domain_array],
        )?;
        writer.write(&new_batch)?;

        // per-batch timing, currently only used by the commented println below
        let _duration = start.elapsed();
        //println!("batch took: {}s", _duration.as_secs());
    }
    writer.close()?;
    let duration = all_batches_start.elapsed();
    println!(
        "from {} rows all {} batches took: {}s, output rows count: {}",
        total_rows,
        total_batches,
        duration.as_secs(),
        total
    );
    Ok(())
}
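
// A minimal sketch of unit tests for the pieces above, assuming this file is the
// crate's binary root so `cargo test` picks them up: they check that the bounded
// Reverse(MinNonNan(..)) heap keeps the k smallest scores (so peek() is the
// per-batch threshold) and that extract_domain pulls the host out of a URI.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bounded_heap_keeps_k_smallest_scores() {
        let scores = [0.9, 0.1, 0.5, 0.3, 0.7];
        let k = 2;
        let mut min_heap = BinaryHeap::new();
        for &s in &scores {
            let candidate = Reverse(MinNonNan(s));
            if min_heap.len() < k {
                min_heap.push(candidate);
            } else if candidate < *min_heap.peek().unwrap() {
                min_heap.pop();
                min_heap.push(candidate);
            }
        }
        // peek() is the k-th smallest score, i.e. the threshold used per batch
        assert_eq!(min_heap.peek().unwrap().0 .0, 0.3);
    }

    #[test]
    fn extract_domain_returns_host() {
        assert_eq!(extract_domain("https://example.com/path/1"), "example.com");
    }
}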

# Python helper: generate a synthetic input shard for the Rust program above.
import os

import numpy as np
import pandas as pd

num_rows = 1_000_000
uids = [f"uid_{i}" for i in range(num_rows)]
uris = [f"https://example.com/path/{i}" for i in range(num_rows)]
qps = np.random.weibull(2, size=num_rows)
data = {"qps": qps, "uid": uids, "uri": uris}
df = pd.DataFrame(data)

# Make sure the partition directory exists before writing.
os.makedirs("/tmp/shards/shard_num=0", exist_ok=True)
df.to_parquet("/tmp/shards/shard_num=0/data.parquet", engine="pyarrow")
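
A minimal sketch of how the written output could be inspected afterwards with the same parquet crate used above: it prints the row count and the key/value metadata set through WriterProperties. The shard number (0) and output path are assumptions matching the defaults in the program above, and this would live in its own small binary.

// Hypothetical standalone checker, assuming shard 0 and the default output path.
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let path = "/tmp/output/shard_num=0/data.parquet";
    let reader = SerializedFileReader::new(File::open(path)?)?;
    let file_metadata = reader.metadata().file_metadata();
    println!("rows written: {}", file_metadata.num_rows());
    // key/value metadata defined before writing began
    if let Some(kv) = file_metadata.key_value_metadata() {
        for pair in kv {
            println!("{} = {:?}", pair.key, pair.value);
        }
    }
    Ok(())
}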