Skip to content

Instantly share code, notes, and snippets.

@westonpace
Created May 19, 2023 12:55
Show Gist options
  • Save westonpace/6f7fdbdc0399501418101851d75091c4 to your computer and use it in GitHub Desktop.
Save westonpace/6f7fdbdc0399501418101851d75091c4 to your computer and use it in GitHub Desktop.
Example writing data to Acero
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <arrow/acero/api.h>
#include <arrow/array/builder_binary.h>
#include <arrow/array/builder_primitive.h>
#include <arrow/compute/api.h>
#include <arrow/dataset/api.h>
#include <arrow/dataset/plan.h>
#include <arrow/filesystem/api.h>
#include <arrow/result.h>
#include <arrow/status.h>
#include <arrow/table.h>

#include <array>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
namespace {
// Utility functions for converting arrow::Result & arrow::Status into
// exceptions
// Unwraps an arrow::Result<T>.
//
// Returns the contained value (moved out of `value`) when the result is OK;
// otherwise throws std::runtime_error carrying the status text.
template <typename T>
T throw_or_assign(arrow::Result<T> value) {
  if (value.ok()) {
    return value.MoveValueUnsafe();
  }
  throw std::runtime_error(value.status().ToString());
}
// Converts a failed arrow::Status into a std::runtime_error carrying the
// status text. Does nothing when the status is OK.
void throw_not_ok(const arrow::Status &status) {
  if (status.ok()) {
    return;
  }
  throw std::runtime_error(status.ToString());
}
// Our sample event which arrives periodically.
//
// The two members mirror the columns declared in event_schema() ("location"
// as a string, "id" as a 32-bit integer).
// NOTE(review): nothing else in the visible code references this struct;
// Accumulator::AddRow takes the location string directly and generates ids
// itself.
struct Event {
// Value for the "location" column.
std::string location;
// Value for the "id" column.
int32_t id;
};
// The schema for our table
std::shared_ptr<arrow::Schema> event_schema() {
static std::shared_ptr<arrow::Schema> kEventSchema =
arrow::schema({arrow::field("location", arrow::utf8()),
arrow::field("id", arrow::int32())});
return kEventSchema;
}
// A queue that converts a push model into a pull model.
//
// This is a basic producer-consumer queue: as data arrives it is stored with
// Push(), and when the reader is polled through ReadNext() it pulls the oldest
// stored batch. Finish() marks the end of the stream.
//
// The queue is unbounded. In a production application you would likely want
// some kind of backpressure if it is possible that data will arrive faster
// than it can be written.
class Queue : public arrow::RecordBatchReader {
 public:
  // All batches served by this reader use the event schema.
  std::shared_ptr<arrow::Schema> schema() const override {
    return event_schema();
  }

  // Stores `batch` at the back of the queue and wakes one thread waiting in
  // ReadNext().
  void Push(std::shared_ptr<arrow::RecordBatch> batch) {
    std::lock_guard<std::mutex> guard(mutex_);
    overflow_.push(std::move(batch));
    data_available_.notify_one();
  }

  // Marks the stream as finished and wakes one thread waiting in ReadNext().
  // Batches that are already queued are still handed out before the
  // end-of-stream marker.
  void Finish() {
    std::lock_guard<std::mutex> guard(mutex_);
    finished_ = true;
    data_available_.notify_one();
  }

  // Blocks until a batch is available or Finish() has been called.
  //
  // On return `*batch` holds the oldest queued batch, or null once the queue
  // is empty and finished. Always returns an OK status.
  arrow::Status ReadNext(std::shared_ptr<arrow::RecordBatch> *batch) override {
    std::unique_lock<std::mutex> lock(mutex_);
    while (!finished_ && overflow_.empty()) {
      data_available_.wait(lock);
    }
    if (overflow_.empty()) {
      // finished_ must be true here; a null batch signals end of stream.
      batch->reset();
    } else {
      *batch = std::move(overflow_.front());
      overflow_.pop();
    }
    return arrow::Status::OK();
  }

 private:
  // Batches that have been pushed but not yet read, oldest first.
  std::queue<std::shared_ptr<arrow::RecordBatch>> overflow_;
  // Set by Finish(); never cleared.
  bool finished_ = false;
  // Guards overflow_ and finished_.
  std::mutex mutex_;
  // Signalled whenever a batch is pushed or the stream is finished.
  std::condition_variable data_available_;
};
// Since data is arriving a row at a time we wait until we've accumulated a
// sufficient number of rows before pushing it into our queue.
//
// Rows are appended to a pair of column builders; once `batch_size` rows have
// been collected they are turned into a record batch and pushed to the queue.
class Accumulator {
 public:
  // `batch_size` is the number of rows per emitted batch. `finished_batches`
  // receives the completed batches; it is stored as a raw, non-owning pointer,
  // so the caller must keep the queue alive while this object is in use.
  Accumulator(int batch_size, Queue *finished_batches)
      : batch_size_(batch_size), finished_batches_queue_(finished_batches) {}

  // Appends one row with the given location and the next sequential id.
  // Throws std::runtime_error if either builder reports an error.
  void AddRow(const std::string &location) {
    throw_not_ok(location_builder_.Append(location));
    throw_not_ok(id_builder_.Append(current_id_++));
    // `>=` rather than `==`: with a non-positive batch size an equality test
    // would never fire and every row would pile up until Finish().
    if (location_builder_.length() >= batch_size_) {
      PushNext();
    }
  }

  // Flushes any partially filled batch and then marks the queue as finished.
  void Finish() {
    if (location_builder_.length() > 0) {
      PushNext();
    }
    finished_batches_queue_->Finish();
  }

 private:
  // Turns the rows collected so far into a record batch, pushes it to the
  // queue and resets both builders for the next batch.
  void PushNext() {
    std::shared_ptr<arrow::Array> locations =
        throw_or_assign(location_builder_.Finish());
    std::shared_ptr<arrow::Array> ids = throw_or_assign(id_builder_.Finish());
    // Keep the row count in the 64-bit type instead of narrowing it to int.
    const int64_t num_rows = locations->length();
    std::shared_ptr<arrow::RecordBatch> batch = arrow::RecordBatch::Make(
        event_schema(), num_rows, {std::move(locations), std::move(ids)});
    finished_batches_queue_->Push(std::move(batch));
    location_builder_.Reset();
    id_builder_.Reset();
  }

  // Number of rows that triggers a flush.
  const int batch_size_;
  // Destination for completed batches (not owned).
  Queue *finished_batches_queue_;
  // Id assigned to the next row; incremented once per AddRow().
  int current_id_ = 0;
  arrow::StringBuilder location_builder_;
  arrow::Int32Builder id_builder_;
};
// Write as many rows as we can in 5 seconds, sleeping for just a tiny bit
// between each one, then finish the accumulator.
//
// Locations are taken round-robin from a fixed list. The 5 second window is
// measured with std::chrono::steady_clock so that adjustments to the system
// clock cannot shorten or lengthen the run.
void WriteData(Accumulator *accumulator) {
  constexpr std::array<const char *, 5> kLocations{
      "Denver", "Atlanta", "Chicago", "New York", "Sacramento"};
  using Clock = std::chrono::steady_clock;

  const Clock::time_point deadline = Clock::now() + std::chrono::seconds(5);
  std::size_t next_location = 0;
  while (Clock::now() < deadline) {
    accumulator->AddRow(kLocations[next_location]);
    // Wrap around to the first location after the last one.
    next_location = (next_location + 1) % kLocations.size();
    std::this_thread::sleep_for(std::chrono::microseconds(10));
  }
  accumulator->Finish();
}
// Builds a two-node plan (record batch reader source -> dataset write) and
// starts it asynchronously.
//
// The returned future completes when the plan is done. There will still be a
// thread used: that one thread reads from the record batch reader until we
// are done.
//
// Output is written as parquet files named "chunk_{i}.parquet" under
// /tmp/my_dataset on the local filesystem, hive-partitioned by the schema that
// remains after removing field 1 from the event schema.
// Throws std::runtime_error if the partition schema cannot be derived.
arrow::Future<> StartPlan(std::shared_ptr<arrow::RecordBatchReader> source) {
  // Source node: pulls batches from the supplied reader.
  arrow::acero::RecordBatchReaderSourceNodeOptions source_opts(
      std::move(source));

  // Partition on everything except the field at index 1.
  std::shared_ptr<arrow::Schema> partition_schema =
      throw_or_assign(event_schema()->RemoveField(1));
  auto partitioning =
      std::make_shared<arrow::dataset::HivePartitioning>(partition_schema);

  // Write node: where and how the dataset is written.
  auto parquet_format = std::make_shared<arrow::dataset::ParquetFileFormat>();
  arrow::dataset::FileSystemDatasetWriteOptions write_opts;
  write_opts.filesystem = std::make_shared<arrow::fs::LocalFileSystem>();
  write_opts.base_dir = "/tmp/my_dataset";
  write_opts.basename_template = "chunk_{i}.parquet";
  write_opts.file_write_options = parquet_format->DefaultWriteOptions();
  write_opts.partitioning = std::move(partitioning);
  arrow::dataset::WriteNodeOptions write_node_opts(write_opts);

  arrow::acero::Declaration plan = arrow::acero::Declaration::Sequence(
      {{"record_batch_reader_source", source_opts},
       {"write", write_node_opts}});
  return arrow::acero::DeclarationToStatusAsync(plan, /*use_threads=*/false);
}
// Runs the whole example: starts the write plan, feeds it rows for a few
// seconds and then waits for the plan to complete.
//
// Throws std::runtime_error if producing rows or running the plan fails.
void RunTest() {
  // NOTE(review): presumably registers the dataset nodes (such as "write")
  // so the plan can look them up by name; confirm against the Arrow docs.
  arrow::dataset::internal::Initialize();
  auto queue = std::make_shared<Queue>();
  Accumulator accumulator(100, queue.get());
  arrow::Future<> plan_fut = StartPlan(queue);
  try {
    WriteData(&accumulator);
  } catch (...) {
    // The producer failed before it could finish the queue. Finish it here so
    // the plan's reader sees end-of-stream instead of blocking forever, and
    // let the plan wind down before `queue` and `accumulator` go away.
    queue->Finish();
    plan_fut.Wait();
    throw;
  }
  throw_not_ok(plan_fut.status());
}
} // namespace
// Entry point: runs the example and reports any failure on stderr.
// Returns 0 on success and 1 if an exception was raised.
int main() {
  try {
    RunTest();
  } catch (const std::exception &err) {
    // Catch the common base class so that failures other than
    // std::runtime_error (e.g. std::bad_alloc) are reported as well instead
    // of terminating the process.
    std::cerr << "An error occurred: " << err.what() << std::endl;
    return 1;
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment