PySpark equivalent to pandas.wide_to_long()
# PySpark equivalent to `pandas.wide_to_long()`
# Credit for melt(): https://stackoverflow.com/questions/41670103/how-to-melt-spark-dataframe
#
# Usage: https://pandas.pydata.org/pandas-docs/stable/generated/pandas.wide_to_long.html
# Note: this was a quick hack; some of the error checking in the original pandas version
# has been stripped out. A usage sketch follows the function definition below.
import re
from typing import Iterable, List

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
def wide_to_long(df, stubnames, i, j, sep="", suffix=r"\d+"):
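    """PySpark equivalent of ``pandas.wide_to_long()``.

    Unpivots columns named ``<stub><sep><suffix>`` into long format, keeping the
    id columns `i` and capturing the suffix values in a new column `j`.
    """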
    def is_list(x):
        return isinstance(x, List)

    def melt(
        df: DataFrame,
        id_vars: Iterable[str],
        value_vars: Iterable[str],
        var_name: str = "variable",
        value_name: str = "value",
    ) -> DataFrame:
        """Convert :class:`DataFrame` from wide to long format."""
        # Create array<struct<variable: str, value: ...>>
        _vars_and_vals = F.array(
            *(
                F.struct(F.lit(c).alias(var_name), F.col(c).alias(value_name))
                for c in value_vars
            )
        )
        # Add to the DataFrame and explode
        _tmp = df.withColumn("_vars_and_vals", F.explode(_vars_and_vals))
        cols = id_vars + [
            F.col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]
        ]
        return _tmp.select(*cols)

    def get_var_names(df, stub, sep, suffix):
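        """Return the columns whose names match ``<stub><sep><suffix>``."""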
        regex = r"^{stub}{sep}{suffix}$".format(
            stub=re.escape(stub), sep=re.escape(sep), suffix=suffix
        )
        pattern = re.compile(regex)
        return [col for col in df.columns if pattern.match(col)]
    def melt_stub(df, stub, i, j, value_vars, sep):
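        """Melt a single stub's columns and strip the stub prefix from column `j`."""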
        newdf = melt(
            df,
            id_vars=i,
            value_vars=value_vars,
            value_name=stub.rstrip(sep),
            var_name=j,
        )
        # newdf[j] = Categorical(newdf[j])
        newdf = newdf.withColumn(j, F.regexp_replace(j, stub + sep, ""))
        # # GH17627 Cast numerics suffixes to int/float
        # newdf[j] = to_numeric(newdf[j], errors='ignore')
        return newdf  # .set_index(i + [j])

    if not is_list(stubnames):
        stubnames = [stubnames]
    else:
        stubnames = list(stubnames)

    if any(col in stubnames for col in df.columns):
        raise ValueError("stubname can't be identical to a column name")

    if not is_list(i):
        i = [i]
    else:
        i = list(i)

    # if df.select(i).duplicated().any():
    #     raise ValueError("the id variables need to uniquely identify each row")

    value_vars = [get_var_names(df, stub, sep, suffix) for stub in stubnames]
    value_vars_flattened = [e for sublist in value_vars for e in sublist]
    id_vars = list(set(df.columns).difference(value_vars_flattened))

    melted = [melt_stub(df, s, i, j, v, sep) for s, v in zip(stubnames, value_vars)]

    if len(melted) > 1:
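        # NOTE: joining a list of DataFrames does not work here; see the fix
        # using functools.reduce in the comment at the end of this page.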
        melted = melted[0].join(melted[1:], on=i, how="outer")
    else:
        melted = melted[0]

    new = df.select(id_vars).join(melted, on=i, how="left")

    return new
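A minimal usage sketch, assuming the function above is in scope and a Spark session is available. The sample data and column names (famid, A1970, A1980) are illustrative, loosely following the pandas wide_to_long() docs; a single stub is used so the multi-stub join issue discussed below is not hit.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Wide frame: one stub "A" with year suffixes, plus the id column "famid".
wide = spark.createDataFrame(
    [(1, 2.5, 3.2), (2, 1.2, 1.3), (3, 0.7, 0.1)],
    ["famid", "A1970", "A1980"],
)

long_df = wide_to_long(wide, stubnames="A", i="famid", j="year")
# Expected columns: famid, year, A, with "year" holding the "1970"/"1980" suffixes.
long_df.show()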
For multiple stubnames, the join of the melted frames fails because DataFrame.join does not accept a list of DataFrames. The problem could be fixed by changing your code:

melted = melted[0].join(melted[1:], on=i, how="outer")

to

melted = reduce(lambda df1, df2: df1.join(df2.drop(j), on=i, how="outer"), melted)

(with reduce imported from functools).
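For reference, a minimal sketch of how the suggested fix slots into the multi-stub branch of wide_to_long(), assuming `from functools import reduce` is added to the imports at the top of the file:

from functools import reduce

# Inside wide_to_long(), the multi-stub branch becomes:
if len(melted) > 1:
    # Sequentially outer-join the per-stub frames on the id variables `i`,
    # dropping the duplicate suffix column `j` from the right-hand frame.
    melted = reduce(
        lambda df1, df2: df1.join(df2.drop(j), on=i, how="outer"), melted
    )
else:
    melted = melted[0]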