General PySpark utilities I didn't find anywhere else.
# dev environment: https://github.com/jplane/pyspark-devcontainer
from pyspark.sql import functions as F, types as T, Column
import string
from typing import Union, Callable
from functools import reduce


def interpolate_string(template: str, strict: bool = True, **kwargs: Union[str, Column]):
    """
    - `template`: string with Python string placeholders
    - `strict`: if False, placeholders with no matching kwarg are automatically replaced with a column of the same name
    - `kwargs`: placeholder values in the form of column names or column objects

    Examples:
    - `dataset.withColumn("interpolated", interpolate_string("hello {world}", world=F.col("column")))`
    - `("hello {world}, {world}", world="column")`
    - `("{hello} {world}", strict=False, world="column")`
    - `("{hello} {{world}}", hello="column")`
    - `("{hello} {{{world}}}", hello="column", world=F.col("abc"))`
    """
    spark_kwargs = {k: v if isinstance(v, Column) else F.col(v) for k, v in kwargs.items()}
    spark_template = template.replace("%", "%%")  # escape format specifiers
    # https://stackoverflow.com/questions/14061724/how-can-i-find-all-placeholders-for-str-format-in-a-python-string-using-a-regex/14061832#14061832
    placeholder_names = [name for _, name, _, _ in string.Formatter().parse(spark_template) if name is not None]
    spark_template = spark_template.format_map({k: "%s" for k in placeholder_names})

    def handler(placeholder_name):
        try:
            col = spark_kwargs[placeholder_name]
        except KeyError as e:
            if strict:
                raise e
            col = F.col(placeholder_name)
        return col

    spark_cols = [handler(name) for name in placeholder_names]
    return F.format_string(spark_template, *spark_cols)
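
# Usage sketch for `interpolate_string` (hypothetical DataFrame and column names,
# assuming an active SparkSession `spark`; not part of the original gist):
#
#   df = spark.createDataFrame([("Alice", "Berlin")], schema=["name", "city"])
#   df = df.withColumn(
#       "greeting",
#       interpolate_string("hello {name} from {place}", name="name", place=F.col("city")),
#   )
#   # greeting -> "hello Alice from Berlin"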


def array_sample(arr: str | Column, n: int = 5):
    """Evenly sample from an array; N must be greater than 1; The entire array is returned if N is greater than the array size; Priority is given to the first and last items."""
    arr = arr if isinstance(arr, Column) else F.col(arr)
    s = F.size(arr)
    sampled = F.array([F.get(arr, F.floor((s - 1) * i / (n - 1))) for i in range(n)])
    return F.when(s > n, sampled).otherwise(arr)
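
# Usage sketch for `array_sample` (hypothetical DataFrame, assuming an active
# SparkSession `spark`; `F.get` requires PySpark >= 3.4):
#
#   df = spark.createDataFrame([[list(range(10))]], schema=["arr"])
#   df.select(array_sample("arr", 5).alias("sampled")).show(truncate=False)
#   # -> [0, 2, 4, 6, 9]  (the first and last items are always included)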


def agg_by(col: Column | str, agg_at: Column | str, agg_f: Callable[[Column], Column] = F.min):
    """Get the row item that is in the same row as the aggregate result of another column.

    Parameters
    ----------
    col: Column to return as the result.
    agg_at: Column which the aggregation will be executed against.
    agg_f: Aggregation applied to each group in `agg_at`. Executing UDFs is not supported.

    Example
    -------
    >>> df = spark.createDataFrame([
    ...     ["A", 0, "2024-01-07"],
    ...     ["A", 1, "2024-01-06"],
    ...     ["A", 2, "2024-01-05"],
    ...     ["B", 3, "2024-01-04"],
    ...     ["B", 4, "2024-01-03"],
    ...     ["B", 5, "2024-01-02"],
    ...     ["B", 6, "2024-01-01"],
    ... ], schema=["A", "B", "C"]).repartition(4)
    >>> df = df.withColumn("C", F.col("C").cast(T.TimestampType()))
    >>> df.groupBy("A").agg(
    ...     F.struct(
    ...         F.collect_list("B").alias("B"),
    ...         F.collect_list("C").alias("C"),
    ...     ).alias("B | C"),
    ...     agg_by("B", "C", F.max).alias("B_by_max_C")
    ... ).show(truncate=False)
    +---+----------------------------------------------------------------+----------+
    |A  |B | C                                                            |B_by_max_C|
    +---+----------------------------------------------------------------+----------+
    |A  |{[0, 1, 2], [2024-01-07, 2024-01-06, 2024-01-05]}               |0         |
    |B  |{[3, 4, 5, 6], [2024-01-04, 2024-01-03, 2024-01-02, 2024-01-01]}|3         |
    +---+----------------------------------------------------------------+----------+
    """
    # F.expr unfortunately prevents the usage of UDFs/pandas_agg_udfs
    agg_item = agg_f(agg_at)._jc.toString()
    col = col if isinstance(col, str) else col._jc.toString()
    agg_at = agg_at if isinstance(agg_at, str) else agg_at._jc.toString()
    return F.expr(  # collect_list does not break rows, i.e. it is only nondeterministic across rows
        f"element_at(collect_list({col}), cast(array_position(collect_list({agg_at}), {agg_item}) as int))"
    )


def sha2_cols(*cols: Column, numBits: int = 256) -> Column:
    """Return a robust hash of multiple columns; Based on `F.sha2`."""
    return F.sha2(F.concat_ws("",
        *cols,  # used for avoiding hash collisions in F.hash
        *[F.hash(col).cast("string") for col in cols]  # used for distinguishing between empty string and nulls
    ), numBits)
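
# Usage sketch for `sha2_cols` (hypothetical DataFrame, assuming an active SparkSession
# `spark`): unlike F.sha2(F.concat_ws("", ...)) over the raw columns alone, a NULL and an
# empty string produce different digests here.
#
#   df = spark.createDataFrame([("a", None), ("a", "")], schema=["x", "y"])
#   df.select(sha2_cols(F.col("x"), F.col("y")).alias("row_hash")).show(truncate=False)
#   # -> two different 64-character hex digests (numBits=256)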


def split_by_sequence(col: Column | str, sequence: list[str]) -> Column:
    r"""
    Splits a string column sequentially around matches of each given pattern.

    Example
    -------
    >>> df = spark.createDataFrame([["A) hello B) world A) abc C) hello"]], schema=["col1"])
    >>> df = df.withColumn("split_by_sequence", split_by_sequence("col1", [r"A\)", r"B\)", r"C\)"]))
    >>> df.show(truncate=False)
    +---------------------------------+--------------------------------+
    |col1                             |split_by_sequence               |
    +---------------------------------+--------------------------------+
    |A) hello B) world A) abc C) hello|[, hello , world A) abc , hello]|
    +---------------------------------+--------------------------------+
    """
    def merge(acc: Column, pattern: str) -> Column:
        element_to_split = F.element_at(acc, -1)
        split = F.split(element_to_split, pattern, 2)
        return F.when(
            F.size(split) == 1,
            acc
        ).otherwise(
            F.concat(
                F.slice(acc, 1, F.size(acc) - 1),  # drop last element
                split
            )
        )

    initial = F.array(col)
    output = reduce(merge, sequence, initial)
    return output