Last active
June 12, 2021 19:07
-
-
Save scravy/8c8a1ee4df3d31c46b5558723de042dd to your computer and use it in GitHub Desktop.
Computing per-country frequencies in Spark with a pandas UDF: installs are apportioned via Hamilton's (largest-remainder) method over exploded rows.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from typing import Dict | |
import numpy as np | |
import pandas as pd | |
import pyspark | |
import pyspark.sql.functions as F | |
from pyspark.sql import DataFrame, SparkSession | |
# noinspection PyUnresolvedReferences | |
from pyspark.sql.pandas.functions import pandas_udf | |
def fields(from_df: DataFrame, *without: str):
    """Yield the column names of *from_df*, skipping any listed in *without*."""
    excluded = set(without)
    for name in from_df.schema.fieldNames():
        if name not in excluded:
            yield name
def country_split(df: DataFrame) -> DataFrame:
    """Explode multi-country rows into one row per country, apportioning the
    integer ``installs`` count among the countries with Hamilton's
    (largest-remainder) method.

    Per-country weights are estimated from the rows that list exactly one
    country: each country's share of the single-country installs total is its
    frequency. Because Hamilton's method distributes the integer remainder,
    the sum of ``installs`` over the result equals the sum over the input
    (``main`` asserts this).
    """
    # Frequencies come only from unambiguous (single-country) rows.
    countries_df: DataFrame = df \
        .where(F.size('countries') == 1) \
        .select(F.explode('countries').alias('country'), 'installs') \
        .groupby('country') \
        .agg(F.sum('installs').alias('installs'))
    total_installs = countries_df.agg(F.sum('installs').alias('installs')).collect()[0]['installs']
    frequencies_df = countries_df.select(
        'country',
        (F.col('installs') / F.lit(total_installs)).alias('frequency'),
    )
    # Collected to the driver so a plain dict can be captured in the UDF closure.
    frequencies = {row['country']: row['frequency'] for row in frequencies_df.collect()}

    def hamilton(row):
        # row = (_countries, installs, country). ``_countries`` is NULL for
        # single-country rows; those bypass the UDF entirely (see below).
        if not row[0]:
            return None
        installs = row[1]
        # Renormalize the global frequencies over just this row's countries.
        fs: Dict[str, float] = {c: frequencies[c] for c in row[0].split(';')}
        total = sum(fs.values())
        ratios = np.array([value / total for value in fs.values()])
        # Hamilton's method: take the integer parts first, then hand out the
        # leftover installs one-by-one to the largest fractional parts.
        frac, results = np.modf(installs * ratios)
        remainder = int(installs - results.sum())
        ixs = np.argsort(frac)[::-1]
        results[ixs[0:remainder]] += 1
        # The input has already been exploded to one country per row, so
        # return only the allocation belonging to this row's country.
        for country, result in zip(fs.keys(), results):
            if country == row[2]:
                return result
        return -1  # NOTE(review): sentinel — looks unreachable; verify.

    @pandas_udf(returnType='int')
    def allocation(countries: pd.Series, installs: pd.Series, country: pd.Series) -> pd.Series:
        # Row-wise apply: zip the three columns back into per-row tuples.
        return pd.concat([countries, installs, country], axis=1).apply(hamilton, axis=1)

    # ';'-joined country list, present only on rows that need apportioning.
    df2: DataFrame = df.withColumn('_countries', F.when(F.size('countries') > 1, F.array_join('countries', ';')))
    # One output row per country; explode_outer keeps rows with empty arrays.
    df3: DataFrame = df2.select(F.explode_outer('countries').alias('country'), *fields(df2, 'countries'))
    return df3.select(
        F
        .when(F.col('_countries').isNotNull(), allocation('_countries', 'installs', 'country'))
        .otherwise(F.col('installs'))
        .alias('installs'),
        *fields(df3, '_countries', 'installs')
    )
def main(spark: SparkSession):
    """Run country_split on a small demo DataFrame and verify that the
    apportioning preserves the overall install total."""
    rows = [
        ("a01", ['B', 'A'], 4),
        ("b02", ['A', 'B'], 30),
        ("c03", ['A', 'B', 'C'], 7),
        ("d04", ['A'], 13),
        ("e05", ['B'], 3),
        ("f06", ['B'], 19),
        ("g07", ['B'], 7),
        ("h08", ['C'], 4),
        ("i09", ['C'], 1),
    ]
    df: DataFrame = spark.createDataFrame(rows, schema=["id", "countries", "installs"])

    resolved: DataFrame = country_split(df)
    resolved.show()

    def total_of(frame: DataFrame):
        # Sum of the 'installs' column, pulled back to the driver.
        return frame.agg(F.sum('installs').alias('total')).collect()[0]['total']

    total_installs_source = total_of(df)
    total_installs_result = total_of(resolved)
    print(total_installs_source, total_installs_result)
    assert total_installs_source == total_installs_result
if __name__ == '__main__':
    # Avoid local hostname-resolution hiccups when running Spark standalone.
    os.environ['SPARK_LOCAL_HOSTNAME'] = "localhost"
    builder = pyspark.sql.SparkSession.Builder()
    # Plain string literal: the original used an f-string with no
    # placeholders (ruff F541), which is just noise.
    builder.master("local[2]")
    builder.config("spark.sql.shuffle.partitions", 2)
    # NOTE(review): presumably required for Arrow-backed pandas UDFs on
    # newer JVMs (netty reflective access) — confirm against the Spark docs.
    builder.config("spark.driver.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true")
    builder.config("spark.executor.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true")
    # SparkSession is a context manager: the session is stopped on exit.
    with builder.getOrCreate() as session:
        main(session)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment