Skip to content

Instantly share code, notes, and snippets.

@aloni636
Last active July 10, 2024 22:14
Show Gist options
  • Save aloni636/f47a0132bb11279b73f1e225fc4a42a9 to your computer and use it in GitHub Desktop.
Save aloni636/f47a0132bb11279b73f1e225fc4a42a9 to your computer and use it in GitHub Desktop.
General PySpark utilities I didn't find anywhere else.
# dev environment: https://github.com/jplane/pyspark-devcontainer
from pyspark.sql import functions as F, Column
import string
from typing import Dict, Union, Callable
def interpolate_string(template: str, strict: bool = True, **kwargs: Union[str, Column]) -> Column:
    """Build a Spark string column by filling `str.format`-style placeholders with column values.

    The Python template is translated into an `F.format_string` pattern: each `{name}` becomes
    `%s` and literal `%` is escaped as `%%`. The matching columns are passed as the pattern's
    arguments, so the interpolation is evaluated row by row inside Spark.

    Parameters
    ----------
    template: string with Python `str.format` named placeholders; `{{` and `}}` produce literal braces.
    strict: if False, placeholders with no matching kwarg are automatically filled from the column
        of the same name; if True (default) such placeholders raise `KeyError`.
    kwargs: placeholder values, in the form of column names or `Column` objects.

    Returns
    -------
    The `Column` produced by `F.format_string`.

    Raises
    ------
    KeyError: `strict` is True and a placeholder has no matching kwarg.
    ValueError: the template is malformed or uses positional placeholders (`{}` / `{0}`).

    Format specs and conversions (`{x:>10}`, `{x!r}`) are applied to the `%s` marker, not to the
    value (e.g. `!r` wraps the value in single quotes); they cannot format Spark values.

    Examples (argument lists for `interpolate_string`):
    - `dataset.withColumn("interpolated", interpolate_string("hello {world}", world=F.col("column")))`
    - `("hello {world}, {world}", world="column")`
    - `("{hello} {world}", strict=False, world="column")`
    - `("{hello} {{world}}", hello="column")`
    - `("{hello} {{{world}}}", hello="column", world=F.col("abc"))`
    """
    # Resolve every kwarg up front so invalid values fail fast, even if the template never uses them.
    columns = {k: v if isinstance(v, Column) else F.col(v) for k, v in kwargs.items()}
    formatter = string.Formatter()
    pattern_parts = []
    placeholder_cols = []
    # Formatter.parse yields (literal_text, field_name, format_spec, conversion) tuples. literal_text
    # already has `{{`/`}}` collapsed to single braces; field_name is None for pure literal chunks.
    # https://stackoverflow.com/questions/14061724/how-can-i-find-all-placeholders-for-str-format-in-a-python-string-using-a-regex/14061832#14061832
    for literal, name, spec, conversion in formatter.parse(template):
        # Escape `%` only in literal text, so placeholder names containing `%` still match their kwarg.
        pattern_parts.append(literal.replace("%", "%%"))
        if name is None:
            continue
        if name == "" or name.isdigit():
            raise ValueError(f"positional placeholders are not supported, use named ones: {template!r}")
        # Apply any !conversion / :spec to the marker exactly as str.format would.
        pattern_parts.append(formatter.format_field(formatter.convert_field("%s", conversion), spec))
        if name in columns:
            placeholder_cols.append(columns[name])
        elif strict:
            raise KeyError(f"no value given for placeholder {name!r}; pass it as a kwarg or use strict=False")
        else:
            placeholder_cols.append(F.col(name))
    return F.format_string("".join(pattern_parts), *placeholder_cols)
def array_sample(arr: str | Column, n: int = 5) -> Column:
    """Evenly sample `n` items from an array column, always keeping the first and the last item.

    Arrays with `n` or fewer items are returned unchanged.

    Parameters
    ----------
    arr: array column, given by name or as a `Column`.
    n: number of items to sample; must be greater than 1 so both endpoints can be kept.

    Returns
    -------
    An array `Column`.

    Raises
    ------
    ValueError: if `n` is less than 2 (the index formula below divides by `n - 1`).
    """
    if n < 2:
        raise ValueError(f"n must be greater than 1, got {n}")
    arr = arr if isinstance(arr, Column) else F.col(arr)
    size = F.size(arr)
    # Sample i sits at floor((size - 1) * i / (n - 1)): index 0 for i == 0, size - 1 for i == n - 1.
    # F.get takes a 0-based index.
    sampled = F.array([F.get(arr, F.floor((size - 1) * i / (n - 1))) for i in range(n)])
    return F.when(size > n, sampled).otherwise(arr)
def agg_by(col: Column | str, agg_at: Column | str, agg_f: Callable[[Column], Column] = F.min) -> Column:
    """Get the row item of `col` that is at the same row as the aggregate result of `agg_at`.

    This is an aggregate expression, meant for use inside `groupBy(...).agg(...)`.
    - Ties: if several rows match the aggregate, the first one in collection order is used, which
      Spark does not guarantee to be stable.
    - No match: if the aggregate is not one of the aggregated values (e.g. `F.avg`), or every
      `agg_at` is null, the result is null.
    - NOTE: for `F.min` / `F.max`, Spark SQL's built-in `min_by` / `max_by` may be preferable.

    Parameters
    ----------
    col: Column to return as result.
    agg_at: Column which the aggregation will be executed against.
    agg_f: Aggregation applied to each group in `agg_at`. Executing UDFs is not supported.

    Example
    -------
    >>> from pyspark.sql import types as T
    >>> df = spark.createDataFrame([
    ...     ["A", 0, "2024-01-07"],
    ...     ["A", 1, "2024-01-06"],
    ...     ["A", 2, "2024-01-05"],
    ...     ["B", 3, "2024-01-04"],
    ...     ["B", 4, "2024-01-03"],
    ...     ["B", 5, "2024-01-02"],
    ...     ["B", 6, "2024-01-01"],
    ... ], schema=["A", "B", "C"]).repartition(4)
    >>> df = df.withColumn("C", F.col("C").cast(T.TimestampType()))
    >>> df.groupBy("A").agg(
    ...     F.struct(
    ...         F.collect_list("B").alias("B"),
    ...         F.collect_list("C").alias("C"),
    ...     ).alias("B | C"),
    ...     agg_by("B", "C", F.max).alias("B_by_max_C")
    ... ).show(truncate=False)
    +---+----------------------------------------------------------------+----------+
    |A  |B | C                                                           |B_by_max_C|
    +---+----------------------------------------------------------------+----------+
    |A  |{[0, 1, 2], [2024-01-07, 2024-01-06, 2024-01-05]}               |0         |
    |B  |{[3, 4, 5, 6], [2024-01-04, 2024-01-03, 2024-01-02, 2024-01-01]}|3         |
    +---+----------------------------------------------------------------+----------+
    """
    # F.expr unfortunately prevents the usage of UDFs/pandas_agg_udfs: every argument is rendered back
    # to SQL text via Column._jc.toString() (requires classic, JVM-backed PySpark).
    agg_at_col = agg_at if isinstance(agg_at, Column) else F.col(agg_at)
    # Pass a Column so custom agg_f callables work even when agg_at is given by name.
    agg_item = agg_f(agg_at_col)._jc.toString()
    col_sql = col if isinstance(col, str) else col._jc.toString()
    agg_at_sql = agg_at if isinstance(agg_at, str) else agg_at._jc.toString()
    # Collect (agg_at, col) pairs in ONE list of structs. A struct is never null, so collect_list
    # keeps rows whose agg_at or col is null. Two separate collect_lists would each drop their own
    # nulls and go out of alignment.
    pairs = f"collect_list(struct({agg_at_sql} AS k, {col_sql} AS v))"
    # 1-based position of the first pair whose key equals the aggregate. array_position returns 0
    # when nothing matches and element_at rejects index 0, so map 0 to NULL instead.
    position = f"cast(nullif(array_position({pairs}.k, {agg_item}), 0) as int)"
    return F.expr(f"element_at({pairs}, {position}).v")
def sha2_cols(*cols: Column, numBits: int = 256) -> Column:
    """Return a robust SHA-2 hash over several columns; built on `F.sha2`.

    The digest input is the separator-less concatenation of every column's value, followed by every
    column's `F.hash` rendered as a string.

    Parameters
    ----------
    cols: columns to hash, in order.
    numBits: SHA-2 digest length, forwarded to `F.sha2`.

    NOTE(review): values are joined without a separator, so boundaries between adjacent values are
    only disambiguated through the appended `F.hash` strings. Changing this layout would change
    every previously computed digest.
    """
    # Raw values: used for avoiding hash collisions in F.hash.
    value_parts = list(cols)
    # Per-column F.hash as text: used for distinguishing between empty strings and nulls
    # (concat_ws skips null values, so a null and "" would otherwise contribute the same).
    hash_parts = [F.hash(part).cast("string") for part in cols]
    digest_input = F.concat_ws("", *value_parts, *hash_parts)
    return F.sha2(digest_input, numBits)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment