@danielchalef
Last active March 8, 2022 10:55
PySpark equivalent to pandas.wide_to_long()
# PySpark equivalent to `pandas.wide_to_long()`
# Credit for melt(): https://stackoverflow.com/questions/41670103/how-to-melt-spark-dataframe
#
# Usage: https://pandas.pydata.org/pandas-docs/stable/generated/pandas.wide_to_long.html
# Note: this was a quick hack; some of the error checking in the original pandas
# implementation has been stripped out.
import re
from functools import reduce
from typing import Iterable, List

import pyspark.sql.functions as F
from pyspark.sql import DataFrame


def wide_to_long(df, stubnames, i, j, sep="", suffix=r"\d+"):
    def is_list(x):
        return isinstance(x, List)

    def melt(
        df: DataFrame,
        id_vars: Iterable[str],
        value_vars: Iterable[str],
        var_name: str = "variable",
        value_name: str = "value",
    ) -> DataFrame:
        """Convert :class:`DataFrame` from wide to long format."""
        # Create array<struct<variable: str, value: ...>>
        _vars_and_vals = F.array(
            *(
                F.struct(F.lit(c).alias(var_name), F.col(c).alias(value_name))
                for c in value_vars
            )
        )
        # Add to the DataFrame and explode
        _tmp = df.withColumn("_vars_and_vals", F.explode(_vars_and_vals))
        cols = id_vars + [
            F.col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]
        ]
        return _tmp.select(*cols)

    def get_var_names(df, stub, sep, suffix):
        # Columns matching "<stub><sep><suffix>" are the wide columns to melt
        regex = r"^{stub}{sep}{suffix}$".format(
            stub=re.escape(stub), sep=re.escape(sep), suffix=suffix
        )
        pattern = re.compile(regex)
        return [col for col in df.columns if pattern.match(col)]

    def melt_stub(df, stub, i, j, value_vars, sep):
        newdf = melt(
            df,
            id_vars=i,
            value_vars=value_vars,
            value_name=stub.rstrip(sep),
            var_name=j,
        )
        # newdf[j] = Categorical(newdf[j])
        # Strip the stub and separator from the melted variable names so that
        # only the suffix remains in column j; re.escape guards against stubs
        # containing regex metacharacters.
        newdf = newdf.withColumn(j, F.regexp_replace(j, re.escape(stub + sep), ""))
        # # GH17627 Cast numeric suffixes to int/float
        # newdf[j] = to_numeric(newdf[j], errors='ignore')
        return newdf  # .set_index(i + [j])

    if not is_list(stubnames):
        stubnames = [stubnames]
    else:
        stubnames = list(stubnames)

    if any(col in stubnames for col in df.columns):
        raise ValueError("stubname can't be identical to a column name")

    if not is_list(i):
        i = [i]
    else:
        i = list(i)

    # if df.select(i).duplicated().any():
    #     raise ValueError("the id variables need to uniquely identify each row")

    value_vars = [get_var_names(df, stub, sep, suffix) for stub in stubnames]
    value_vars_flattened = [e for sublist in value_vars for e in sublist]
    id_vars = list(set(df.columns).difference(value_vars_flattened))

    melted = [melt_stub(df, s, i, j, v, sep) for s, v in zip(stubnames, value_vars)]
    if len(melted) > 1:
        # DataFrame.join takes a single DataFrame, not a list; fold the melted
        # frames together pairwise. pandas joins the stubs on the index i + [j].
        melted = reduce(
            lambda left, right: left.join(right, on=i + [j], how="outer"), melted
        )
    else:
        melted = melted[0]

    new = df.select(id_vars).join(melted, on=i, how="left")
    return new
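
A minimal usage sketch, mirroring the example in the pandas.wide_to_long docs linked above; the column names and values here are illustrative, not part of the gist:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Wide frame: one row per id, with year-suffixed columns A1970..B1980
    df = spark.createDataFrame(
        [(1, "a", "d", 2.5, 3.2), (2, "b", "e", 1.2, 1.3), (3, "c", "f", 0.7, 0.1)],
        ["id", "A1970", "A1980", "B1970", "B1980"],
    )

    long_df = wide_to_long(df, stubnames=["A", "B"], i="id", j="year")
    long_df.show()
    # One row per (id, year) pair, with columns id, year, A, and B.
    # Note the suffix column year is left as a string ("1970", "1980").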
@MarcelBeining

What is F?

@danielchalef (Author) commented Mar 2, 2021

import pyspark.sql.functions as F

@MarcelBeining

Ah, OK, thanks.
Unfortunately I get an error (Python 2.7, Windows 10):
in wide_to_long(df, stubnames, i, j, sep, suffix)
    949
    950     if len(melted) > 1:
--> 951         melted = melted[0].join(melted[1:], on=i, how="outer")
    952     else:
    953         melted = melted[0]

C:\ProgramData\Miniconda2\lib\site-packages\pyspark\sql\dataframe.pyc in join(self, other, on, how)
   1050             on = self._jseq([])
   1051         assert isinstance(how, basestring), "how should be basestring"
-> 1052         jdf = self._jdf.join(other._jdf, on, how)
   1053         return DataFrame(jdf, self.sql_ctx)
   1054

AttributeError: 'list' object has no attribute '_jdf'

@MarcelBeining

The problem could be fixed by changing your code:
melted = melted[0].join(melted[1:], on=i, how="outer")
to
melted = reduce(lambda df1, df2: df1.join(df2.drop(j), on=i, how="outer"), melted)
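
A note on this fix, for later readers: reduce is a builtin in Python 2 but must be imported on Python 3, and dropping j and joining on i alone only lines the stubs up when each id carries a single suffix value. pandas aligns the stubs on its i + [j] index, which corresponds to joining on both the id columns and the suffix column. A sketch of that variant (the melted, i, and j names are from the function above):

    from functools import reduce

    # Fold the per-stub frames together pairwise, aligning rows on the id
    # columns plus the suffix column j (mirrors pandas' i + [j] index join).
    melted = reduce(lambda df1, df2: df1.join(df2, on=i + [j], how="outer"), melted)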
