"""Wrapper around BigQuery call."""
from __future__ import annotations
from typing import Any, Iterable
import logging
from google.cloud import bigquery_storage
from google.cloud.bigquery_storage_v1 import exceptions as bqstorage_exceptions
from google.cloud.bigquery_storage_v1 import types, writer
from google.protobuf import descriptor_pb2
from google.protobuf.descriptor import Descriptor
from loguru import logger
class DefaultStreamManager:  # pragma: no cover
    """Manage access to the _default write stream of a table."""

    def __init__(
        self,
        table_path: str,
        message_protobuf_descriptor: Descriptor,
        bigquery_storage_write_client: bigquery_storage.BigQueryWriteClient,
    ):
        """Init."""
        self.stream_name = f"{table_path}/_default"
        self.message_protobuf_descriptor = message_protobuf_descriptor
        self.write_client = bigquery_storage_write_client
        self.append_rows_stream = None

    def _init_stream(self):
        """Init the underlying append-rows stream."""
        # Create a template with fields needed for the first request.
        request_template = types.AppendRowsRequest()
        # The initial request must contain the stream name.
        request_template.write_stream = self.stream_name
        # So that BigQuery knows how to parse the serialized_rows, generate a
        # protocol buffer representation of our message descriptor.
        proto_schema = types.ProtoSchema()
        proto_descriptor = descriptor_pb2.DescriptorProto()  # pylint: disable=no-member
        self.message_protobuf_descriptor.CopyToProto(proto_descriptor)
        proto_schema.proto_descriptor = proto_descriptor
        proto_data = types.AppendRowsRequest.ProtoData()
        proto_data.writer_schema = proto_schema
        request_template.proto_rows = proto_data
        # Create an AppendRowsStream using the request template created above.
        self.append_rows_stream = writer.AppendRowsStream(
            self.write_client, request_template
        )

    def send_appendrowsrequest(
        self, request: types.AppendRowsRequest
    ) -> writer.AppendRowsFuture:
        """Send a request on the stream, initializing the stream if needed."""
        try:
            if self.append_rows_stream is None:
                self._init_stream()
            return self.append_rows_stream.send(request)
        except bqstorage_exceptions.StreamClosedError:
            # The stream needs to be reinitialized: drop it and let the
            # caller decide whether to retry.
            self.append_rows_stream.close()
            self.append_rows_stream = None
            raise

    # Use as a context manager.
    def __enter__(self) -> DefaultStreamManager:
        """Enter the context manager: open the stream and return self."""
        self._init_stream()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context manager: close the stream."""
        if self.append_rows_stream is not None:
            # Shut down background threads and close the streaming connection.
            self.append_rows_stream.close()
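

# Hedged sketch (not part of the original gist): because
# send_appendrowsrequest() resets the stream and re-raises on
# StreamClosedError, a caller can simply retry the same request and a fresh
# AppendRowsStream will be opened on the next attempt. The helper name and
# default attempt count below are illustrative assumptions.
def append_with_retry(
    stream_manager: DefaultStreamManager,
    request: types.AppendRowsRequest,
    attempts: int = 2,
) -> writer.AppendRowsFuture:
    """Retry an append after the underlying stream has been reset."""
    for attempt in range(attempts):
        try:
            return stream_manager.send_appendrowsrequest(request)
        except bqstorage_exceptions.StreamClosedError:
            # On the last attempt, propagate the error to the caller.
            if attempt == attempts - 1:
                raise
    raise ValueError("attempts must be >= 1")

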
class BigqueryWriteManager:
    """Encapsulate the BigQuery Storage write client for one table."""

    def __init__(
        self,
        project_id: str,
        dataset_id: str,
        table_id: str,
        bigquery_storage_write_client: bigquery_storage.BigQueryWriteClient,
        pb2_descriptor: Descriptor,
    ):  # pragma: no cover
        """Create a BigqueryWriteManager."""
        self.bigquery_storage_write_client = bigquery_storage_write_client
        self.table_path = self.bigquery_storage_write_client.table_path(
            project_id, dataset_id, table_id
        )
        self.pb2_descriptor = pb2_descriptor

    def write_rows(self, pb_rows: Sequence[Any]) -> None:
        """Write data rows."""
        logger.info(f"Writing {len(pb_rows)} rows to {self.table_path}")
        with DefaultStreamManager(
            self.table_path, self.pb2_descriptor, self.bigquery_storage_write_client
        ) as target_stream_manager:
            logger.info(f"Inited stream manager for {self.table_path}")
            proto_rows = types.ProtoRows()
            # Create a batch of row data by appending proto2 serialized bytes
            # to the serialized_rows repeated field.
            for row in pb_rows:
                proto_rows.serialized_rows.append(row.SerializeToString())
            logger.info(f"Created proto rows for {self.table_path}")
            # Create an append-rows request containing the batch.
            request = types.AppendRowsRequest()
            proto_data = types.AppendRowsRequest.ProtoData()
            proto_data.rows = proto_rows
            request.proto_rows = proto_data
            logger.info(f"Sending append rows request for {self.table_path}")
            future = target_stream_manager.send_appendrowsrequest(request)
            # Wait for the append-rows request to finish.
            logger.info(f"Waiting for append rows request for {self.table_path}")
            future.result()
        logger.info(f"Done writing {len(pb_rows)} rows to {self.table_path}")
if __name__ == "__main__":
    from gumtap_protobuf import image_vector_v2

    def create_image_vector(i):
        data = {
            "image_url": f"https://www.amazon.com/product123/image{i}.jpg",
            "model_type": "resnet50",
            "created_at": "2023-07-08T10:16:30Z",
        }
        for dim in range(768):
            data[f"dim_{dim:05d}"] = dim * 1.0
        return image_vector_v2(**data)

    project_id = "automatic-asset-359722"
    dataset_id = "gumlog"
    table_id = "image_vector_v2"
    manager = BigqueryWriteManager(
        project_id,
        dataset_id,
        table_id,
        bigquery_storage.BigQueryWriteClient(),
        image_vector_v2.DESCRIPTOR,
    )
    image_vectors = [create_image_vector(i) for i in range(100)]
    manager.write_rows(image_vectors)