@bnaul
Created July 15, 2021 17:23
Dask <-> BigQuery helpers

def _stream_to_dfs(bqs_client, stream_name, schema, timeout):
    """Given a Storage API client and a stream name, return all dataframes for that stream."""
    return [
        pyarrow.ipc.read_record_batch(
            pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch), schema
        ).to_pandas()
        for message in bqs_client.read_rows(name=stream_name, offset=0, timeout=timeout)
    ]

@dask.delayed
def _read_rows_arrow(
    *,
    make_create_read_session_request: callable,
    partition_field: str = None,
    project_id: str,
    stream_name: str = None,
    timeout: int,
) -> pd.DataFrame:
    """Read a single batch of rows via BQ Storage API, in Arrow binary format.

    Args:
        project_id: BigQuery project
        make_create_read_session_request: callable that builds the request to pass to
            `bqs_client.create_read_session` as `request`
        partition_field: BigQuery field for partitions, to be used as Dask index col for divisions
            NOTE: Please set if specifying `row_restriction` filters in TableReadOptions.
        stream_name: BigQuery Storage API Stream "name".
            NOTE: Please set if reading from Storage API without any `row_restriction`.
            https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream

    NOTE: `partition_field` and `stream_name` kwargs are mutually exclusive.

    Adapted from https://github.com/googleapis/python-bigquery-storage/blob/a0fc0af5b4447ce8b50c365d4d081b9443b8490e/google/cloud/bigquery_storage_v1/reader.py.
    """
    with bigquery_client(project_id, with_storage_api=True) as (bq_client, bqs_client):
        session = bqs_client.create_read_session(make_create_read_session_request())
        schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(session.arrow_schema.serialized_schema))

        if (partition_field is not None) and (stream_name is not None):
            raise ValueError(
                "The kwargs `partition_field` and `stream_name` are mutually exclusive."
            )
        elif partition_field is not None:
            shards = [
                df
                for stream in session.streams
                for df in _stream_to_dfs(bqs_client, stream.name, schema, timeout=timeout)
            ]
            # NOTE: if no rows satisfy the row_restriction, then `shards` will be an empty list
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]
            shards = [shard.set_index(partition_field, drop=True) for shard in shards]
        elif stream_name is not None:
            shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
            # NOTE: BQ Storage API can return empty streams
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]
        else:
            raise NotImplementedError("Please specify either `partition_field` or `stream_name`.")

        return pd.concat(shards)

def gbq_as_dask_df(
    project_id: str,
    dataset_id: str,
    table_id: str,
    partition_field: str = None,
    partitions: Iterable[str] = None,
    row_filter="",
    fields: List[str] = (),
    read_timeout: int = 3600,
):
    """Read table as dask dataframe using BigQuery Storage API via Arrow format.

    If `partition_field` and `partitions` are specified, then the resulting dask dataframe
    will be partitioned along the same boundaries. Otherwise, partitions will be approximately
    balanced according to BigQuery stream allocation logic.

    If `partition_field` is specified but would not otherwise be selected (i.e., an explicit
    `fields` list was passed that omits it), it will still be included in the query so that it
    is available for dask dataframe indexing.

    Args:
        project_id: BigQuery project
        dataset_id: BigQuery dataset within project
        table_id: BigQuery table within dataset
        partition_field: to specify filters of form "WHERE {partition_field} = ..."
        partitions: all values to select of `partition_field`
        fields: names of the fields (columns) to select (default empty, i.e. select all fields)
        read_timeout: # of seconds an individual read request has before timing out

    Returns:
        dask dataframe

    See https://github.com/dask/dask/issues/3121 for additional context.
    """
    if (partition_field is None) and (partitions is not None):
        raise ValueError("Specified `partitions` without `partition_field`.")

    # If `partition_field` is not part of the `fields` filter, fetch it anyway to be able
    # to set it as dask dataframe index. We want this to be able to have consistent:
    # BQ partitioning + dask divisions + pandas index values
    if (partition_field is not None) and fields and (partition_field not in fields):
        fields = (partition_field, *fields)

    # These read tasks seem to cause deadlocks (or at least long-stuck workers out of touch with
    # the scheduler), particularly when mixed with other tasks that execute C code. Anecdotally,
    # annotating the tasks with a higher priority seems to help (but not fully solve) the issue at
    # the expense of higher cluster memory usage.
    with bigquery_client(project_id, with_storage_api=True) as (
        bq_client,
        bqs_client,
    ), dask.annotate(priority=1):
        table_ref = bq_client.get_table(".".join((dataset_id, table_id)))
        if table_ref.table_type == "VIEW":
            # Materialize the view since the operations below don't work on views.
            logging.warning("Materializing view in order to read into dask. This may be expensive.")
            query = f"SELECT * FROM `{full_id(table_ref)}`"
            table_ref, _, _ = execute_query(query)

        # The protobuf types can't be pickled (may be able to tweak w/ copyreg), so instead use a
        # factory func.
        def make_create_read_session_request(row_filter=""):
            return bigquery_storage.types.CreateReadSessionRequest(
                max_stream_count=0,  # 0 -> use as many streams as BQ Storage will provide
                parent=f"projects/{project_id}",
                read_session=bigquery_storage.types.ReadSession(
                    data_format=bigquery_storage.types.DataFormat.ARROW,
                    read_options=bigquery_storage.types.ReadSession.TableReadOptions(
                        row_restriction=row_filter,
                        selected_fields=fields,
                    ),
                    table=table_ref.to_bqstorage(),
                ),
            )

        # Create a read session in order to detect the schema.
        # Read sessions are lightweight and will be auto-deleted after 24 hours.
        session = bqs_client.create_read_session(
            make_create_read_session_request(row_filter=row_filter)
        )
        schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(session.arrow_schema.serialized_schema))
        meta = schema.empty_table().to_pandas()
        delayed_kwargs = dict(prefix=f"{dataset_id}.{table_id}-")

        if partition_field is not None:
            if row_filter:
                raise ValueError("Cannot pass both `partition_field` and `row_filter`")
            delayed_kwargs["meta"] = meta.set_index(partition_field, drop=True)

            if partitions is None:
                logging.info(
                    "Specified `partition_field` without `partitions`; reading full table."
                )
                partitions = read_gbq(
                    f"SELECT DISTINCT {partition_field} FROM {dataset_id}.{table_id}",
                    project_id=project_id,
                )[partition_field].tolist()

            # TODO generalize to ranges (as opposed to discrete values)
            partitions = sorted(partitions)
            delayed_kwargs["divisions"] = (*partitions, partitions[-1])
            row_filters = [
                f'{partition_field} = "{partition_value}"' for partition_value in partitions
            ]
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=partial(
                        make_create_read_session_request, row_filter=row_filter
                    ),
                    partition_field=partition_field,
                    project_id=project_id,
                    timeout=read_timeout,
                )
                for row_filter in row_filters
            ]
        else:
            delayed_kwargs["meta"] = meta
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=make_create_read_session_request,
                    project_id=project_id,
                    stream_name=stream.name,
                    timeout=read_timeout,
                )
                for stream in session.streams
            ]

        return dd.from_delayed(dfs=delayed_dfs, **delayed_kwargs)
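
# NOTE: the sketch below is illustrative only (not part of the original helpers); the project,
# dataset, table, and field names are placeholders. It shows how `gbq_as_dask_df` maps BigQuery
# partition values onto dask divisions when `partition_field`/`partitions` are passed.
def _example_gbq_as_dask_df():
    ddf = gbq_as_dask_df(
        project_id="my-project",
        dataset_id="my_dataset",
        table_id="events",
        partition_field="date",
        partitions=["2021-01-01", "2021-01-02"],
        fields=["user_id", "value"],  # `date` is added automatically for indexing
    )
    # One dask partition per BigQuery partition value; `date` becomes the index and
    # divisions are ("2021-01-01", "2021-01-02", "2021-01-02").
    return ddf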

def dask_df_to_gbq(
    ddf: dd.DataFrame,
    project_id: str = None,
    dataset_id: str = None,
    table_id: str = None,
    bq_schema: List[bigquery.schema.SchemaField] = None,
    pa_schema: pyarrow.Schema = None,
    partition_by: str = None,
    cluster_by: List[str] = None,
    clear_existing: bool = True,
    retries: int = None,
    write_index: bool = False,
):
    """Upload dask dataframe to BigQuery by writing Parquet to GCS and running a load job.

    Args:
        ddf: dask dataframe to upload
        project_id: BigQuery project
        dataset_id: BigQuery dataset within project
        table_id: BigQuery table within dataset
        bq_schema: resulting table schema
            TODO infer from data; load_table_from_dataframe tries but has issues w/ some types
        pa_schema: pyarrow schema for the intermediate parquet files
        partition_by: (date or timestamp) field to partition by
        cluster_by: fields to cluster by
        clear_existing: whether to delete the existing table
        retries: number of retries for dask computation
        write_index: whether to write the index to parquet

    TODO: Change this to only write to GCS parquet pattern, and have the framework handle
    downstream resolution into a BQ view of GCS
    """
    dask_tmp_pattern = "gs://model_bigquery_tmp/dask_dataframe_tmp/{token}/{timestamp}/*.parquet"
    dask_tmp_path = dask_tmp_pattern.format(token=tokenize(ddf), timestamp=int(1e6 * time.time()))

    logging.info(f"Writing dask dataframe to {dask_tmp_path} ...")
    ddf.to_parquet(
        path=os.path.dirname(dask_tmp_path),
        engine="pyarrow",
        write_index=write_index,
        write_metadata_file=False,
        schema=pa_schema,
    )

    with bigquery_client(project_id) as bq_client:
        if table_id:
            if not dataset_id:
                raise ValueError("Cannot pass table_id without dataset_id")
            dataset_ref = bq_client.create_dataset(dataset_id, exists_ok=True)
            table_ref = dataset_ref.table(table_id)
            if clear_existing:
                bq_client.delete_table(table_ref, not_found_ok=True)
        else:
            table_ref = get_temporary_table(bq_client)
            logging.info("Loading to temporary table %s", table_ref.table_id)

        logging.info(
            "Loading %s to %s.%s.%s ...",
            dask_tmp_path,
            table_ref.project,
            table_ref.dataset_id,
            table_ref.table_id,
        )
        job_config = bigquery.LoadJobConfig(
            clustering_fields=cluster_by,
            schema=bq_schema,
            autodetect=(bq_schema is None),
            source_format=bigquery.SourceFormat.PARQUET,
            time_partitioning=(
                bigquery.TimePartitioning(field=partition_by) if partition_by else None
            ),
            write_disposition="WRITE_EMPTY",
        )
        job = bigquery.Client(project_id).load_table_from_uri(
            source_uris=dask_tmp_path,
            destination=table_ref,
            job_config=job_config,
        )
        try:
            return job.result()
        except ClientError:
            logging.error(f"Load job failed with the following errors: {job.errors}")
            raise
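
# NOTE: illustrative sketch only (not part of the original helpers); the schema and names below
# are placeholders. `bq_schema` is passed explicitly; the TODO in the docstring above notes that
# schema inference has issues with some types.
def _example_dask_df_to_gbq(ddf: dd.DataFrame):
    dask_df_to_gbq(
        ddf,
        project_id="my-project",
        dataset_id="my_dataset",
        table_id="events_copy",
        bq_schema=[
            bigquery.SchemaField("date", "DATE"),
            bigquery.SchemaField("user_id", "STRING"),
            bigquery.SchemaField("value", "FLOAT"),
        ],
        partition_by="date",     # DATE/TIMESTAMP field -> time-partitioned table
        cluster_by=["user_id"],
        clear_existing=True,     # drop any existing my_dataset.events_copy first
    )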

def query_to_dask_df(
    query: str, project_id: str = None, chunksize: int = None, tmp_path: str = None
) -> dd.DataFrame:
    """Read a BigQuery query result into a dask dataframe using Parquet in GCS as an intermediary.

    GCS export tends to have more balanced sharding and better performance compared to the BigQuery
    Storage API (used in `gbq_as_dask_df`), but unlike that approach it does not allow for special
    handling of partitioned data.
    """
    if tmp_path is None:
        tmp_path = f"gs://model_bigquery_tmp/{uuid4().hex}/*.parquet"
    logging.info("Writing intermediate query_to_dask_df parquet files to %s", tmp_path)
    query = f"EXPORT DATA OPTIONS(uri='{tmp_path}', format=PARQUET) AS\n{query}"

    with bigquery_client(project_id) as bq_client:
        job = bq_client.query(query)
        job.result()  # block until complete

    with dask.annotate(retries=3):  # Some reads seem to fail transiently - see RAD-1820.
        ddf = dd.read_parquet(tmp_path)

    if chunksize:
        num_rows = int(
            job.__dict__["_properties"]["statistics"]["query"]["exportDataStatistics"][
                "rowCount"
            ]
        )
        ddf = ddf.repartition(npartitions=max(num_rows // chunksize, 1))

    return ddf
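
# NOTE: illustrative sketch only (not part of the original helpers); the query and chunk size are
# placeholders. `chunksize` repartitions the result so that each dask partition holds roughly
# that many rows.
def _example_query_to_dask_df():
    return query_to_dask_df(
        "SELECT user_id, SUM(value) AS total FROM my_dataset.events GROUP BY user_id",
        project_id="my-project",
        chunksize=1_000_000,
    )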

@ncclementi

@bnaul this is great, I'd like to try out this code but I'm missing imports and some information about dependencies.
Do you have a snippet of code/notebook that I can use to be able to run an example?

I'm guessing I need

import dask 
import distributed
import pandas as pd
import pyarrow
from google.cloud import bigquery

It looks like you are also using google-cloud-bigquery-storage

But there are a couple of things I'm not sure where they come from. Would you mind pointing out their origin? Are these custom functions?

  • bigquery_client() as in with bigquery_client(project_id)
  • get_temporary_table(bq_client)

Thanks in advance

@bnaul (Author) commented Jul 26, 2021

@ncclementi those look right; the bigquery client helper is

@contextmanager
def bigquery_client(project_id=_DEFAULT_BQ_PROJECT, with_storage_api=False):
    # Ignore google auth credentials warning
    warnings.filterwarnings(
        "ignore", "Your application has authenticated using end user credentials"
    )

    bq_storage_client = None
    bq_client = bigquery.Client(project_id)
    try:
        if with_storage_api:
            bq_storage_client = bigquery_storage.BigQueryReadClient(
                credentials=bq_client._credentials
            )
            yield bq_client, bq_storage_client
        else:
            yield bq_client
    finally:
        bq_client.close()

and get_temporary_table is just

def get_temporary_table(bq_client):
    return bq_client.get_table(f"tmp.{uuid.uuid4().hex}")

where tmp is a dataset we use with a 1-day expiration policy.
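
For completeness, here's a rough (untested) set of imports the gist assumes, plus a minimal call; the project/dataset/table names are placeholders, `read_gbq` comes from pandas-gbq, and `full_id` / `execute_query` are other internal helpers of ours not shown here:

    import logging
    import os
    import time
    import uuid
    import warnings
    from contextlib import contextmanager
    from functools import partial
    from typing import Iterable, List
    from uuid import uuid4

    import dask
    import dask.dataframe as dd
    import pandas as pd
    import pyarrow
    from dask.base import tokenize
    from google.api_core.exceptions import ClientError
    from google.cloud import bigquery, bigquery_storage
    from pandas_gbq import read_gbq

    ddf = gbq_as_dask_df("my-project", "my_dataset", "events")
    print(ddf.head())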
