Created
July 22, 2020 00:42
-
-
Save jaketf/23447b179ee9df0d385c6d43b444ca12 to your computer and use it in GitHub Desktop.
Using Pandas in an Apache Beam PTransform
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2020 Google Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
# | |
# This software is provided as-is, | |
# without warranty or representation for any use or purpose. | |
# Your use of it is subject to your agreement with Google. | |
from functools import partial | |
from typing import Callable, ClassVar, Dict, List, Tuple, TypeVar, Union | |
import apache_beam as beam | |
import fastrand | |
import pandas as pd | |
KT = TypeVar("KT") | |
@beam.typehints.with_input_types(Union[Dict, Tuple[KT, Dict]])
@beam.typehints.with_output_types(Dict)
class PandasOverDicts(beam.PTransform):
    """Apply a pandas DataFrame -> DataFrame function over a PCollection of dicts.

    The input dicts are grouped (per window, per key) into pandas DataFrames,
    ``pandas_fn`` is applied to each DataFrame, and the rows of the resulting
    DataFrames are emitted as dicts. This is only appropriate when the data
    volume per window (per key) fits in memory on a single worker.

    Sharding input:
        There are three ways to decide which elements of a window end up in
        the same DataFrame:
        1. Specify ``split_by``: the value stored under that key in each
           input dict becomes the grouping key.
        2. Specify ``num_shards``: each element is paired with a random key
           in ``[0, num_shards)``.
        3. Specify neither: elements are paired with a random key in
           ``[0, DEFAULT_NUM_SHARDS)``.
        Alternatively pass ``skip_sharding=True`` when the input PCollection
        already consists of ``(key, dict)`` tuples.

    Parallelization:
        Within a window even a single shard can get large. Unless
        ``skip_batching`` is set, elements are grouped with
        ``beam.GroupIntoBatches`` so that no DataFrame is built from more
        than ``max_df_rows`` records.

    Guidance on how to choose sharding and ``max_df_rows``:
        1. "Pandas logic does not do aggregation"
           Use the defaults.
        2. "Pandas logic does aggregation but this can be sharded on a known
           key"
           This requires a single DataFrame per window per key, i.e. fitting
           all of that data in memory on a single worker. This does not
           scale for very hot keys; consider expressing the aggregation as
           Beam SQL or a CombineFn instead.
           Use ``split_by=<your-key>, skip_batching=True``.
        3. "Pandas logic aggregates over all inputs in the window and there
           is no reasonable key to shard on"
           This requires a single DataFrame per window. This does not scale
           for large windows; consider expressing the aggregation as Beam
           SQL or a CombineFn instead.
           Use ``num_shards=1, skip_batching=True``.

    Args:
        pandas_fn: Function specifying a pandas DataFrame -> DataFrame
            transformation.
        split_by: The key in the input dicts on which to shard. Any value
            other than ``None`` (including ``0`` or ``""``) enables it.
            Cannot be combined with ``num_shards`` or ``skip_sharding``.
        num_shards: Number of random shards. A falsy value (``None`` or
            ``0``) selects ``DEFAULT_NUM_SHARDS``.
        skip_sharding: Skip sharding by ``split_by`` or ``num_shards``
            because the input PCollection is already key-value tuples.
        skip_batching: Skip batching elements (per window per key) into
            DataFrames of bounded size because all of these elements are
            required by the aggregation logic in ``pandas_fn``.
        max_df_rows: Maximum number of elements collected into a single
            DataFrame when batching. Default is 10000.
        *args: Forwarded to ``beam.PTransform.__init__``.
        **kwargs: Forwarded to ``beam.PTransform.__init__``.

    Raises:
        ValueError: If ``skip_sharding`` is combined with ``split_by`` or
            ``num_shards``, or ``split_by`` is combined with ``num_shards``.
    """

    DEFAULT_NUM_SHARDS = 100  # type: ClassVar[int]

    def __init__(self,
                 pandas_fn: Callable[[pd.DataFrame], pd.DataFrame],
                 split_by: KT = None,
                 num_shards: int = None,
                 skip_sharding: bool = False,
                 skip_batching: bool = False,
                 max_df_rows: int = 10000,
                 *args,
                 **kwargs):
        # Let the base PTransform initialise itself instead of silently
        # dropping the extra arguments.
        super().__init__(*args, **kwargs)
        self.pandas_fn = pandas_fn
        self.max_df_rows = max_df_rows
        self.split_by = split_by
        self.num_shards = num_shards
        self.skip_sharding = skip_sharding
        self.skip_batching = skip_batching

        # `split_by` is compared against None rather than tested for
        # truthiness so that falsy keys such as 0 or "" are honoured.
        has_split_by = self.split_by is not None
        if self.skip_sharding and (has_split_by or self.num_shards):
            raise ValueError(
                "skip_sharding should not be used with split_by or "
                f"num_shards. Got skip_sharding: {skip_sharding}, "
                f"split_by: {split_by}, and num_shards: {num_shards}")
        if has_split_by and self.num_shards:
            raise ValueError(
                "split_by cannot be used with num_shards or skip_sharding")

    def expand(self, input_or_inputs: beam.pvalue.PCollection):
        """Key the input, group it into DataFrame-sized chunks, apply pandas_fn.

        Args:
            input_or_inputs: PCollection of dicts, or of ``(key, dict)``
                tuples when ``skip_sharding`` is set.

        Returns:
            PCollection of dicts, one per row of the DataFrames returned by
            ``pandas_fn``.
        """
        pandas_wrapper = partial(_pandas_over_dict, pandas_fn=self.pandas_fn)

        if self.skip_sharding:
            keyed_values = input_or_inputs
        elif self.split_by is not None:
            # Bind to a local so the lambda does not capture `self`.
            split_by = self.split_by
            keyed_values = (
                input_or_inputs
                | "Extract key" >> beam.WithKeys(lambda x: x.get(split_by))
            )
        else:
            num_shards = self.num_shards or PandasOverDicts.DEFAULT_NUM_SHARDS
            # WithKeys calls the key function with the element; the element
            # is ignored here because the shard is chosen at random.
            keyed_values = (
                input_or_inputs
                | "Pair with random key" >> beam.WithKeys(
                    lambda _: fastrand.pcg32bounded(num_shards))
            )

        if self.skip_batching:
            # One group, and therefore one DataFrame, per window per key.
            prepared_input = (
                keyed_values
                | "Parallelize Dataframe by Key" >> beam.GroupByKey())
        else:
            # GroupIntoBatches must see the individual keyed elements (not
            # the output of a GroupByKey) for max_df_rows to bound the size
            # of each group. This relies on GroupIntoBatches emitting
            # (key, batch) pairs, as current Beam releases do.
            prepared_input = (
                keyed_values
                | "Avoid Large Dataframes" >> beam.GroupIntoBatches(
                    self.max_df_rows))

        # Both branches above yield (key, records) pairs; each pair is
        # wrapped in a one-element list because _pandas_over_dict takes a
        # list of keyed groups.
        return (prepared_input
                | "Apply Pandas Logic" >> beam.Map(
                    lambda keyed_group: pandas_wrapper([keyed_group]))
                | "Flatten results and drop key" >> beam.FlatMap(lambda x: x))
def _pandas_over_dict(
        keyed_data: Union[Tuple[KT, List[Dict]], List[Tuple[KT, List[Dict]]]],
        pandas_fn: Callable[[pd.DataFrame], pd.DataFrame] = None) -> List[Dict]:
    """Execute a DataFrame -> DataFrame transformation over keyed dict records.

    Args:
        keyed_data: Either a single ``(key, records)`` tuple or a list of
            such tuples, where ``records`` is an iterable of dicts. The keys
            are ignored; the records of all groups are combined, in order,
            into one DataFrame.
        pandas_fn: Transformation applied to the DataFrame built from the
            records.

    Returns:
        The rows of the DataFrame returned by ``pandas_fn`` as a list of
        dicts. An empty list is returned, without calling ``pandas_fn``,
        when there are no records.

    Raises:
        ValueError: If ``pandas_fn`` is None.
    """
    if pandas_fn is None:
        raise ValueError("pandas_fn must not be None")
    # A bare tuple is one (key, records) pair; anything else is treated as a
    # sequence of such pairs.
    groups = [keyed_data] if isinstance(keyed_data, tuple) else keyed_data
    # Materialise the records so that lazy iterables are supported and no
    # group after the first is silently dropped.
    records = [record for _key, group in groups for record in group]
    if not records:
        return []
    df_in = pd.DataFrame.from_records(records)
    return pandas_fn(df_in).to_dict("records")
# Everything below would normally live in a separate unit-test file; it is included here for completeness of the gist.
import apache_beam as beam | |
import numpy as np | |
import pandas as pd | |
from apache_beam.testing.test_pipeline import TestPipeline | |
from apache_beam.testing.util import assert_that, equal_to | |
from calculate_margin.transforms.pandas_transforms import PandasOverDicts | |
def foo_pandas_fn(df: pd.DataFrame) -> pd.DataFrame:
    """Add an ``avg`` column holding the mean ``count`` of each ``id`` group.

    The input DataFrame is modified in place and also returned.

    Args:
        df: DataFrame with at least the columns ``id`` and ``count``.

    Returns:
        The same DataFrame with the additional ``avg`` column.
    """
    # Select the aggregated column explicitly: assigning the DataFrame
    # returned by a whole-frame transform to one column only works while
    # the frame has exactly one non-key column.
    df["avg"] = df.groupby("id")["count"].transform("mean")
    return df
def test_pandas_over_dicts():
    """PandasOverDicts sharded on ``id`` adds each group's mean count as ``avg``."""
    input_rows = [
        {"id": 0, "count": 1},
        {"id": 0, "count": 3},
        {"id": 1, "count": 4},
        {"id": 1, "count": 6},
    ]
    # Mean of `count` within each `id` group: (1 + 3) / 2 and (4 + 6) / 2.
    mean_by_id = {0: 2, 1: 5}
    expected_rows = [
        dict(row, avg=mean_by_id[row["id"]]) for row in input_rows
    ]

    with TestPipeline() as pipeline:
        source = pipeline | beam.Create(input_rows)
        result = source | PandasOverDicts(foo_pandas_fn, split_by="id")
        # The assertion is attached inside the `with` block so that it is
        # part of the pipeline graph before the pipeline runs.
        assert_that(result, equal_to(expected_rows))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment