-
-
Save danielchalef/c395dc529f6cf16cbe63d6f8f27c54f5 to your computer and use it in GitHub Desktop.
# PySpark equivalent to `pandas.wide_to_long()`
# Credit for melt(): https://stackoverflow.com/questions/41670103/how-to-melt-spark-dataframe
#
# Usage: https://pandas.pydata.org/pandas-docs/stable/generated/pandas.wide_to_long.html
# Note: this was a quick hack, and some error checking from the original pandas version has been stripped out.
import re
from functools import reduce
from typing import Iterable, List

import pyspark.sql.functions as F
from pyspark.sql import DataFrame
def wide_to_long(df, stubnames, i, j, sep="", suffix=r"\d+"):
    """PySpark equivalent of :func:`pandas.wide_to_long`.

    Unpivots a wide DataFrame: every column named ``<stub><sep><suffix>``
    is melted into long format, with the suffix captured in a new column
    *j* and the values placed in a column named after the stub.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Wide-format DataFrame.
    stubnames : str or list of str
        Stub name(s): the common prefixes of the wide-format columns.
    i : str or list of str
        Column(s) that identify each row.
    j : str
        Name of the new column that will hold the extracted suffix.
    sep : str, optional
        Separator between the stub and the suffix in column names.
    suffix : str, optional
        Regular expression matching the suffix part (default: digits).

    Returns
    -------
    pyspark.sql.DataFrame
        Long-format DataFrame with columns ``id_vars + [j] + stubnames``.

    Raises
    ------
    ValueError
        If any stub name is identical to an existing column name.
    """

    def melt(
        df: DataFrame,
        id_vars: Iterable[str],
        value_vars: Iterable[str],
        var_name: str = "variable",
        value_name: str = "value",
    ) -> DataFrame:
        """Convert :class:`DataFrame` from wide to long format."""
        # Build array<struct<variable: str, value: ...>> — one struct per
        # melted column, pairing the column name with its value.
        _vars_and_vals = F.array(
            *(
                F.struct(F.lit(c).alias(var_name), F.col(c).alias(value_name))
                for c in value_vars
            )
        )
        # Explode the array so each (column, value) pair becomes its own row.
        _tmp = df.withColumn("_vars_and_vals", F.explode(_vars_and_vals))
        cols = id_vars + [
            F.col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]
        ]
        return _tmp.select(*cols)

    def get_var_names(df, stub, sep, suffix):
        """Return the df columns matching ``^<stub><sep><suffix>$``."""
        regex = r"^{stub}{sep}{suffix}$".format(
            stub=re.escape(stub), sep=re.escape(sep), suffix=suffix
        )
        pattern = re.compile(regex)
        return [col for col in df.columns if pattern.match(col)]

    def melt_stub(df, stub, i, j, value_vars, sep):
        """Melt one stub's columns; the stripped suffix goes into column *j*."""
        newdf = melt(
            df,
            id_vars=i,
            value_vars=value_vars,
            value_name=stub.rstrip(sep),
            var_name=j,
        )
        # Keep only the suffix in j by stripping the "<stub><sep>" prefix
        # from the original column name.
        newdf = newdf.withColumn(j, F.regexp_replace(j, stub + sep, ""))
        # NOTE: unlike pandas (GH17627), numeric suffixes are left as strings.
        return newdf

    # Normalize scalar arguments to lists (use the builtin `list` — the
    # original `isinstance(x, typing.List)` check is deprecated).
    if not isinstance(stubnames, list):
        stubnames = [stubnames]
    else:
        stubnames = list(stubnames)

    if any(col in stubnames for col in df.columns):
        raise ValueError("stubname can't be identical to a column name")

    if not isinstance(i, list):
        i = [i]
    else:
        i = list(i)

    value_vars = [get_var_names(df, stub, sep, suffix) for stub in stubnames]
    value_vars_flattened = [e for sublist in value_vars for e in sublist]
    id_vars = list(set(df.columns).difference(value_vars_flattened))

    melted = [melt_stub(df, s, i, j, v, sep) for s, v in zip(stubnames, value_vars)]

    # BUG FIX: DataFrame.join takes a single DataFrame, not a list, so the
    # original `melted[0].join(melted[1:], ...)` raised
    # AttributeError: 'list' object has no attribute '_jdf'.
    # Fold the per-stub frames together pairwise instead, joining on the id
    # columns plus the suffix column j — mirroring pandas' index of i + [j].
    # (reduce over a single-element list just returns that element, so the
    # original len(melted) > 1 branch is no longer needed.)
    melted = reduce(
        lambda left, right: left.join(right, on=i + [j], how="outer"), melted
    )

    # Re-attach any id columns that were not part of the melt.
    new = df.select(id_vars).join(melted, on=i, how="left")
    return new
import pyspark.sql.functions as F
Ah, OK — thanks.
Unfortunately, I get an error (Python 2.7, Windows 10):
in wide_to_long(df, stubnames, i, j, sep, suffix)
949
950 if len(melted) > 1:
--> 951 melted = melted[0].join(melted[1:], on=i, how="outer")
952 else:
953 melted = melted[0]
C:\ProgramData\Miniconda2\lib\site-packages\pyspark\sql\dataframe.pyc in join(self, other, on, how)
1050 on = self._jseq([])
1051 assert isinstance(how, basestring), "how should be basestring"
-> 1052 jdf = self._jdf.join(other._jdf, on, how)
1053 return DataFrame(jdf, self.sql_ctx)
1054
AttributeError: 'list' object has no attribute '_jdf'
The problem could be fixed by changing your code:
melted = melted[0].join(melted[1:], on=i, how="outer")
to
melted = reduce(lambda df1, df2: df1.join(df2.drop(j), on=i, how="outer"), melted)
what is F?