Skip to content

Instantly share code, notes, and snippets.

@imvladikon
Last active January 29, 2023 01:56
Show Gist options
  • Save imvladikon/4315752e2b244c7dbc765fb66df99fe5 to your computer and use it in GitHub Desktop.
Save imvladikon/4315752e2b244c7dbc765fb66df99fe5 to your computer and use it in GitHub Desktop.
duckdb + huggingface datasets
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import duckdb
import pyarrow as pa
from datasets import Dataset
# ibis is an optional dependency: the ibis-backed helper further down is only
# defined when this import succeeds.
try:
    from ibis.backends.base.sql.alchemy import AlchemyTable
except ImportError:
    # Only a missing/unimportable ibis should disable the feature; a bare
    # ``except`` would also swallow KeyboardInterrupt and SystemExit.
    IBIS_AVAILABLE = False
else:
    IBIS_AVAILABLE = True
class DatasetQuery:
    """
    duckdb wrapper for executing queries over huggingface datasets

    The table passed to the constructor is registered under ``table_name`` on
    a private in-memory duckdb connection, so SQL given to :meth:`query` can
    refer to it by that name.
    """

    def __init__(self, arrow_table, table_name="arrow_table"):
        """
        Open an in-memory duckdb connection and register ``arrow_table`` on it.

        :param arrow_table: table object handed unchanged to
            ``connection.register`` (``from_hf_dataset`` passes the Arrow
            table backing a dataset).
        :param table_name: name the table is registered under, i.e. the name
            to use in SQL.
        """
        self.table_name = table_name
        self.connection = duckdb.connect(":memory:")
        self.connection.register(table_name, arrow_table)

    def query(self, query: str) -> "Dataset":
        """
        Run ``query`` on the duckdb connection and return the result as a
        ``Dataset``.

        :param query: SQL text passed as-is to ``connection.query``.
        :return: a ``Dataset`` built from the Arrow table of the query result.
        """
        arrow_result = self.connection.query(query).to_arrow_table()
        # The result is written as an Arrow IPC stream into an in-memory
        # buffer, which is the input ``Dataset.from_buffer`` takes.
        with pa.BufferOutputStream() as sink:
            with pa.RecordBatchStreamWriter(
                sink, schema=arrow_result.schema
            ) as writer:
                writer.write_table(arrow_result)
            # Read the buffer only after the writer has been closed, and
            # while the sink itself is still open.
            buffer = sink.getvalue()
        return Dataset.from_buffer(buffer)

    @classmethod
    def from_hf_dataset(cls, dataset: "Dataset", **kwargs) -> "DatasetQuery":
        """
        Alternate constructor: build a query wrapper from a ``Dataset``.

        :param dataset: dataset whose underlying table is registered in duckdb.
        :param kwargs: forwarded to the constructor (e.g. ``table_name``).
        :return: an instance of the class this method is called on, so
            subclasses get instances of themselves.
        """
        # NOTE(review): ``dataset.data.table`` is the table stored on the
        # dataset; presumably an indices mapping left by select/shuffle/filter
        # is NOT applied here - confirm, and flatten indices first if needed.
        arrow_table = dataset.data.table
        return cls(arrow_table, **kwargs)
if IBIS_AVAILABLE:

    class IbisTableFactory(AlchemyTable):
        """
        duckdb+ibis wrapper
        """

        @staticmethod
        def from_hf_dataset(
            dataset: Dataset, table_name: str
        ) -> "ibis.expr.types.Table":
            """
            Register a dataset's table in an in-memory duckdb backend of ibis
            and return the ibis table bound to it.

            :param dataset: dataset whose underlying table
                (``dataset.data.table``) is registered.
            :param table_name: name the table is registered under.
            :return: the object produced by ``con.table(table_name)`` - an
                ibis table expression, not a pandas DataFrame; call
                ``.execute()`` on an expression to materialise a result.
            """
            # Imported here so the module stays importable without ibis.
            import ibis

            arrow_table = dataset.data.table
            con = ibis.connect('duckdb://:memory:')
            con.register(arrow_table, table_name=table_name)
            return con.table(table_name)
if __name__ == "__main__":
    # Demo: load a CSV as a dataset, then query it through plain duckdb SQL
    # and, when ibis is installed, through an ibis expression.
    dataset = Dataset.from_csv("data/insurance.csv")

    dataset_query = DatasetQuery.from_hf_dataset(
        dataset, table_name="insurance_demo_table"
    )
    first_rows = dataset_query.query("SELECT * FROM insurance_demo_table LIMIT 10")
    print(first_rows.to_pandas())

    if IBIS_AVAILABLE:
        ibis_table = IbisTableFactory.from_hf_dataset(
            dataset, table_name="insurance_demo_table"
        )
        # Mean of the "bmi" column for each "smoker" group.
        mean_bmi = ibis_table.group_by("smoker").bmi.mean()
        print(mean_bmi.execute())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment