Skip to content

Instantly share code, notes, and snippets.

@goraj
Last active May 31, 2023 11:09
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save goraj/aafdbfb534f4df77e4d04f3575cd1126 to your computer and use it in GitHub Desktop.
Save goraj/aafdbfb534f4df77e4d04f3575cd1126 to your computer and use it in GitHub Desktop.
explode_multi_column.py
from typing import List
import pandas as pd
def explode_multi_column(
    df: pd.DataFrame, index_cols: List[str], sep: str
) -> pd.DataFrame:
    r"""Split delimited string columns into one row per element.

    Every column of ``df`` that is not listed in ``index_cols`` is split on
    ``sep`` and the pieces are exploded in lockstep: the i-th piece of one
    split column lands on the same output row as the i-th piece of every
    other split column. The ``index_cols`` values are repeated on each
    produced row.

    ``sep`` is always matched literally, never as a regular expression. This
    differs from pandas' default, which treats multi-character patterns as
    regexes.

    Args:
        df: Input frame. Non-index columns are accessed through
            ``Series.str``, so they must hold strings (or missing values).
        index_cols: Columns copied unchanged onto every exploded row.
        sep: Literal separator between elements.

    Returns:
        A new frame whose columns are ``index_cols`` followed by the split
        columns (in ``df`` column order), with a fresh ``RangeIndex``.

    Raises:
        ValueError: If, within one row, the split columns produce different
            numbers of elements, so their pieces cannot be paired up.
        KeyError: If a name in ``index_cols`` is not a column of ``df``.

    Example::

        >>> df = pd.DataFrame(
        ...     dict(
        ...         site=[
        ...             "Dagger Complex\nLucius D. Clay Kaserne\nRobinson Barracks",
        ...             "Torii Station\nYokota Air Base",
        ...         ],
        ...         branch=["Army\nMarine Corps\nAir Force", "Army\nAir Force"],
        ...         country=["Germany", "Japan"],
        ...     )
        ... )
        >>> explode_multi_column(df, index_cols=["country"], sep="\n")
           country                   site        branch
        0  Germany         Dagger Complex          Army
        1  Germany  Lucius D. Clay Kaserne  Marine Corps
        2  Germany      Robinson Barracks     Air Force
        3    Japan          Torii Station          Army
        4    Japan        Yokota Air Base     Air Force
    """
    index_cols = list(index_cols)
    split_columns = [column for column in df.columns if column not in index_cols]

    if not split_columns:
        # Nothing to split: each input row maps to exactly one output row.
        return df[index_cols].reset_index(drop=True)

    # Reorder up front so the result already has index columns first.
    # Copy so the column assignments below never touch the caller's frame.
    exploded = df[index_cols + split_columns].copy()
    for column in split_columns:
        # regex=False: without it, pandas treats a multi-character sep as a
        # regular expression (e.g. " | " would split on single spaces).
        exploded[column] = exploded[column].str.split(sep, regex=False)

    # Multi-column explode pairs list elements row by row. It raises
    # ValueError on mismatched counts instead of silently misaligning rows.
    return exploded.explode(split_columns, ignore_index=True)
# Demo: run explode_multi_column on a small frame and print before/after.
df = pd.DataFrame(
    {
        "site": [
            "Dagger Complex\nLucius D. Clay Kaserne\nRobinson Barracks",
            "Torii Station\nYokota Air Base",
        ],
        "branch": ["Army\nMarine Corps\nAir Force", "Army\nAir Force"],
        "country": ["Germany", "Japan"],
    }
)

# Label and frame on separate lines, exactly as two print() calls would emit.
print("original dataframe", df, sep="\n")
print("result")
print(explode_multi_column(df, index_cols=["country"], sep="\n"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment