@scravy · Last active June 12, 2021
Computing country frequencies with Spark: a pandas UDF apportions installs from multi-country rows via Hamilton's method, after exploding them into one row per country.
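Hamilton's (largest-remainder) method turns proportional shares into integers that still sum to the original total: hand out the integer parts first, then give the leftover units to the shares with the largest fractional remainders. A minimal, Spark-free sketch of just that step (hamilton_round is an illustrative name, not part of the gist):

import numpy as np

def hamilton_round(total: int, weights) -> np.ndarray:
    # e.g. hamilton_round(7, [13, 29, 5]) -> array([2., 4., 1.])
    ratios = np.asarray(weights, dtype=float)
    ratios = ratios / ratios.sum()
    frac, parts = np.modf(total * ratios)
    leftover = int(total - parts.sum())
    parts[np.argsort(frac)[::-1][:leftover]] += 1
    return parts

The script below embeds this same computation in a pandas UDF, using per-country install frequencies as the weights.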
import os
from typing import Dict
import numpy as np
import pandas as pd
import pyspark
import pyspark.sql.functions as F
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import pandas_udf
def fields(from_df: DataFrame, *without: str):
    """Yield the field names of from_df, excluding the given ones."""
    yield from (f for f in from_df.schema.fieldNames() if f not in without)
def country_split(df: DataFrame) -> DataFrame:
    # Estimate per-country frequencies from the unambiguous rows only,
    # i.e. those that list exactly one country.
    countries_df: DataFrame = df \
        .where(F.size('countries') == 1) \
        .select(F.explode('countries').alias('country'), 'installs') \
        .groupby('country') \
        .agg(F.sum('installs').alias('installs'))
    total_installs = countries_df.agg(F.sum('installs').alias('installs')).collect()[0]['installs']
    frequencies_df = countries_df.select(
        'country',
        (F.col('installs') / F.lit(total_installs)).alias('frequency'),
    )
    # Collected to the driver; the dict is captured by the UDF below.
    frequencies = {row['country']: row['frequency'] for row in frequencies_df.collect()}
    def hamilton(row):
        # row = (_countries, installs, country); _countries is None for
        # rows that listed a single country, which need no apportioning.
        if not row[0]:
            return None
        installs = row[1]
        # Assumes every country in a multi-country row also occurs on its
        # own somewhere; otherwise this lookup raises a KeyError.
        fs: Dict[str, float] = {c: frequencies[c] for c in row[0].split(';')}
        total = sum(fs.values())
        ratios = np.array([value / total for value in fs.values()])
        # Hamilton's method: allocate the integer parts first, then give
        # the leftover installs to the largest fractional remainders.
        frac, results = np.modf(installs * ratios)
        remainder = int(installs - results.sum())
        ixs = np.argsort(frac)[::-1]
        results[ixs[0:remainder]] += 1
        for country, result in zip(fs.keys(), results):
            if country == row[2]:
                return result
        return -1
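    # Worked example, hand-computed from the sample data in main() below:
    # the single-country rows give frequencies A=13/47, B=29/47, C=5/47, so
    # a row with _countries='A;B;C' and installs=7 has raw shares of about
    # [1.94, 4.32, 0.74]; the integer parts [1, 4, 0] leave 2 installs,
    # which go to the two largest fractions, yielding A=2, B=4, C=1.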
    @pandas_udf(returnType='int')
    def allocation(countries: pd.Series, installs: pd.Series, country: pd.Series) -> pd.Series:
        return pd.concat([countries, installs, country], axis=1).apply(hamilton, axis=1)

    # Keep the original list (joined with ';') on rows that need apportioning,
    # then explode so that each (id, country) pair becomes its own row.
    df2: DataFrame = df.withColumn('_countries', F.when(F.size('countries') > 1, F.array_join('countries', ';')))
    df3: DataFrame = df2.select(F.explode_outer('countries').alias('country'), *fields(df2, 'countries'))
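    # For the sample row ("c03", ['A', 'B', 'C'], 7) this produces three
    # rows, one per country, each carrying _countries='A;B;C' and
    # installs=7; single-country rows come out with _countries=None.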
    return df3.select(
        F
        .when(F.col('_countries').isNotNull(), allocation('_countries', 'installs', 'country'))
        .otherwise(F.col('installs'))
        .alias('installs'),
        *fields(df3, '_countries', 'installs')
    )
def main(spark: SparkSession):
    df: DataFrame = spark.createDataFrame([
        ("a01", ['B', 'A'], 4),
        ("b02", ['A', 'B'], 30),
        ("c03", ['A', 'B', 'C'], 7),
        ("d04", ['A'], 13),
        ("e05", ['B'], 3),
        ("f06", ['B'], 19),
        ("g07", ['B'], 7),
        ("h08", ['C'], 4),
        ("i09", ['C'], 1),
    ], schema=["id", "countries", "installs"])
    resolved: DataFrame = country_split(df)
    resolved.show()
    # Sanity check: apportioning must not change the total number of installs.
    total_installs_source = df.agg(F.sum('installs').alias('total')).collect()[0]['total']
    total_installs_result = resolved.agg(F.sum('installs').alias('total')).collect()[0]['total']
    print(total_installs_source, total_installs_result)
    assert total_installs_source == total_installs_result
if __name__ == '__main__':
    os.environ['SPARK_LOCAL_HOSTNAME'] = "localhost"
    builder = pyspark.sql.SparkSession.Builder()
    builder.master("local[2]")
    builder.config("spark.sql.shuffle.partitions", 2)
    # Required for Arrow-based pandas UDFs on JDK 9+.
    builder.config("spark.driver.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true")
    builder.config("spark.executor.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true")
    with builder.getOrCreate() as session:
        main(session)
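# For the sample data above, both printed totals should be 88, and the
# multi-country rows work out (by hand) to: a01 -> A=1, B=3; b02 -> A=9,
# B=21; c03 -> A=2, B=4, C=1. Single-country rows keep their counts, and
# the row order of resolved.show() is not deterministic.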