Skip to content

Instantly share code, notes, and snippets.

@jaketf
Last active June 9, 2020 03:18
Show Gist options
  • Save jaketf/5a3820b91552aa4fb24aaa95388fb5c7 to your computer and use it in GitHub Desktop.
Save jaketf/5a3820b91552aa4fb24aaa95388fb5c7 to your computer and use it in GitHub Desktop.
Apache Beam Python Example: Side Input look up with cache
# Copyright 2020 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import argparse
import logging
from abc import abstractmethod
from cachetools import Cache, LRUCache, LFUCache
from typing import Any, Dict, Iterable, Union
import apache_beam as beam
from apache_beam.pvalue import AsList
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.options.pipeline_options import (GoogleCloudOptions,
PipelineOptions,
SetupOptions)
"""
This module demonstrates use of BigQuery side inputs for look up tables.
Example Use:
python3 main.py --project=<your-gcp-project> # used for BigQuery temp dataset for side input
"""
# TODO This reusable interface should live in its own module and use TypeVars for key type and value type.
@beam.typehints.with_input_types(
    Any, Iterable[Union[Dict[str, Any], bigquery.TableRow]])
class LookupWithCacheLargeSideInputFn(beam.DoFn):
    """Base class for DoFns that look values up in a (BigQuery) side input.

    Using (BigQuery) side inputs can be tricky but often follows the same
    pattern: scan the side input for the row matching a key and remember the
    answer so the scan is not repeated. This base class captures that pattern
    so it can be reused by subclassing.

    This class can be used with AsList or AsIter side inputs.

    Subclasses must implement ``extract_lookup_key``, ``extract_lookup_value``
    and ``process``. They may also change the caching algorithm by rebinding
    ``lookup_cache`` to an instance of any ``cachetools.Cache``.
    """

    # Subclasses may choose to override this.
    # This is a class-level attribute, so the cache object is shared by every
    # instance of a class that does not rebind it.
    lookup_cache: Cache = LRUCache(maxsize=1024)

    # NOTE(review): @abstractmethod is only enforced at instantiation time
    # when the metaclass derives from ABCMeta -- confirm whether beam.DoFn
    # provides that. The NotImplementedError below is the effective guard.
    @staticmethod
    @abstractmethod
    def extract_lookup_key(row: Union[Dict[str, Any], bigquery.TableRow]):
        """Return the lookup key of a side input element.

        Subclasses should override this method to define how the key is
        extracted from ``row``.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def extract_lookup_value(row: Union[Dict[str, Any], bigquery.TableRow]):
        """Return the lookup value of a side input element.

        Subclasses should override this method to define how the value is
        extracted from ``row``. Like ``extract_lookup_key`` this is a static
        method and therefore takes only the row.
        """
        raise NotImplementedError

    @staticmethod
    def get_by_name(row: Union[Dict, bigquery.TableRow], key):
        """Look up a value by name in a dict or bigquery.TableRow.

        This provides safety across runners in case one chooses to provide
        the side input elements as dict or as bigquery.TableRow.

        Args:
            row: A side input element.
            key: Name of the field to read from ``row``.

        Returns:
            ``row[key]`` when ``row`` is a dict; ``None`` when ``row`` is
            neither a dict nor a bigquery.TableRow.

        Raises:
            NotImplementedError: If ``row`` is a bigquery.TableRow; lookup by
                name is not implemented for that type yet.
            KeyError: If ``row`` is a dict without ``key``.
        """
        if isinstance(row, bigquery.TableRow):
            raise NotImplementedError
        if isinstance(row, dict):
            return row[key]

    @abstractmethod
    def process(self,
                element: str,
                lookup_table: Iterable[Union[Dict[str, Any],
                                             bigquery.TableRow]]):
        """Process one element using the side input; subclasses override."""
        raise NotImplementedError

    def lookup(self,
               lookup_table: Iterable[Union[Dict[str, Any],
                                            bigquery.TableRow]],
               key):
        """Return the side input value for ``key``, consulting the cache first.

        On a cache miss the side input is scanned row by row; the first row
        whose key equals ``key`` wins and its value is cached.

        Args:
            lookup_table: The side input rows to search on a cache miss.
            key: The lookup key, compared with ``==`` against
                ``extract_lookup_key(row)``.

        Returns:
            The value extracted from the matching row, or ``None`` when no
            row matches. Misses are not cached, so a key that is absent from
            the side input triggers a full scan on every call.
        """
        # Ask the cache for the key itself rather than testing the cached
        # value for truthiness: a legitimately cached falsy value (0, '',
        # None) must count as a hit, otherwise it is re-scanned every time.
        try:
            return self.lookup_cache[key]  # return from cache.
        except KeyError:
            pass  # first time we see this key: fall through to the scan.

        for row in lookup_table:  # loop through the side input iterable
            if self.extract_lookup_key(row) == key:
                value = self.extract_lookup_value(row)
                # add the lookup to the cache to speed up future lookups for
                # this key.
                self.lookup_cache[key] = value
                # don't parse more rows once we've found our lookup value,
                # exit eagerly.
                return value
        return None
# Each output element maps 'state_abbr' to a str and 'land_area' to the
# looked-up area, so the value type cannot be declared as int.
@beam.typehints.with_output_types(Dict[str, Any])
class StateLandAreaBQWithCacheLargeSideInputFn(LookupWithCacheLargeSideInputFn):
    """Example subclass of LookupWithCacheLargeSideInputFn for state land area.

    Looks up a state's land area, keyed by state abbreviation, in a side
    input built from a BigQuery public table.

    This is for example purposes only.

    This example assumes a side input from the public BigQuery table
    `bigquery-public-data:utility_us.us_states_area` described here:
    https://www.kaggle.com/bigquery/utility-us
    """

    # We happen to know the number of states a priori so we can size the
    # cache accordingly.
    lookup_cache = LFUCache(maxsize=50)

    @staticmethod
    def extract_lookup_key(
            row: Union[Dict[str, Any], bigquery.TableRow]) -> str:
        """Return the 'state_abbreviation' field of a side input row."""
        key_field = 'state_abbreviation'
        return LookupWithCacheLargeSideInputFn.get_by_name(row, key_field)

    @staticmethod
    def extract_lookup_value(
            row: Union[Dict[str, Any], bigquery.TableRow]) -> int:
        """Return the 'area_land_meters' field of a side input row."""
        value_field = 'area_land_meters'
        return LookupWithCacheLargeSideInputFn.get_by_name(row, value_field)

    def process(self,
                element: str,
                lookup_table: Iterable[Union[Dict[str, Any],
                                             bigquery.TableRow]]):
        """Emit the land area for one state abbreviation.

        Args:
            element: The state abbreviation to look up.
            lookup_table: Side input rows of the states area table.

        Returns:
            A single-element list holding a dict with the keys
            'state_abbr' (the input element) and 'land_area' (the value
            returned by ``lookup``, which is None when no row matches).
        """
        area = self.lookup(lookup_table, element)
        # This logging statement is for example purposes only.
        # It is an anti-pattern to log per element.
        # Lazy %-style arguments skip the string formatting when INFO
        # records are filtered out; the rendered message is unchanged.
        logging.info('input state abbreviation element: %s '
                     'lookup area: %s', element, area)
        return [{'state_abbr': element, 'land_area': area}]
def run(argv=None):
    """Build and run the example pipeline for the side input lookup DoFn.

    The pipeline reads the public states-area BigQuery table, hands it as a
    list side input to StateLandAreaBQWithCacheLargeSideInputFn and looks up
    a few sample state abbreviations. The call blocks until the pipeline
    has finished.

    Args:
        argv: Command line arguments. ``--project`` is required; flags the
            parser does not recognise are forwarded to PipelineOptions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--project',
        required=True,
        type=str,
        help="Google Cloud project id",
    )
    known_args, pipeline_args = parser.parse_known_args(argv)

    pipeline_options = PipelineOptions(flags=pipeline_args)
    pipeline_options.view_as(GoogleCloudOptions).project = known_args.project
    # Pickle the main session so module-level names reach the workers.
    pipeline_options.view_as(SetupOptions).save_main_session = True

    pipeline = beam.Pipeline(options=pipeline_options)

    # Side input: the whole lookup table read from BigQuery.
    states_area_source = beam.io.BigQuerySource(
        'bigquery-public-data:utility_us.us_states_area',
    )
    lookup_source = pipeline | beam.io.Read(states_area_source)

    # Main input: a handful of state abbreviations to look up.
    sample_states = (
        pipeline
        | 'Create sample states' >> beam.Create(['NJ', 'MA', 'OR']))
    _ = (
        sample_states
        | 'Log Lookups' >> beam.ParDo(
            StateLandAreaBQWithCacheLargeSideInputFn(),
            AsList(lookup_source)))

    pipeline.run().wait_until_finish()
if __name__ == '__main__':
    # Let INFO-level records through on the root logger before running.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment