Skip to content

Instantly share code, notes, and snippets.

@twang2218
Created September 16, 2023 07:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save twang2218/4ac3b7c75296e650647a69d12efc62b2 to your computer and use it in GitHub Desktop.
from joypy import joyplot
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import math
def ridge_plot(data, by, columns, figsize=(12,8), colormap="Spectral", overlap=0.5, alpha=0.9, linecolor='white'):
    """Draw a grid of ridge (joy) plots, one panel per entry of *columns*.

    Each panel is a ``joyplot`` of ``data`` grouped by ``by``, giving one
    ridge per distinct value of ``data[by]``. The panels are placed on a
    roughly square grid of matplotlib subfigures, and the figure is shown
    with ``plt.show()``.

    Args:
        data: DataFrame holding the feature columns and the grouping column.
        by: Name of the grouping column; each distinct value becomes a ridge.
        columns: Sequence whose items are either a column name or a list of
            column names. Each item becomes one panel, and a list overlays
            several columns in that panel. A single column-name string is
            also accepted and treated as one panel.
        figsize: Size of the whole figure in inches.
        colormap: Palette name passed to ``sns.color_palette`` to colour the
            ridges.
        overlap: Vertical overlap between neighbouring ridges, applied as a
            negative ``hspace`` between the per-group axes.
        alpha: Fill opacity of the ridges.
        linecolor: Colour of the ridge outlines.

    Raises:
        ValueError: If *columns* is empty.
    """
    # Accept a bare column name for convenience (backward-compatible extension).
    if isinstance(columns, str):
        columns = [columns]
    # Normalise each panel spec to a list WITHOUT mutating the caller's list.
    panels = [[spec] if isinstance(spec, str) else list(spec) for spec in columns]
    num_of_panels = len(panels)
    if num_of_panels == 0:
        raise ValueError("columns must contain at least one column name")

    # One Axes per group (one ridge each). nunique() ignores NaN, matching a
    # default groupby over `by`.
    num_of_classes = data[by].nunique()

    # Near-square layout: ncols = floor(sqrt(n)), and enough rows to fit all panels.
    ncols = math.isqrt(num_of_panels)
    nrows = math.ceil(num_of_panels / ncols)

    fig = plt.figure(figsize=figsize)
    # squeeze=False always yields a 2-D array, so 1xN and 1x1 grids index the same way.
    subfigs = fig.subfigures(nrows, ncols, squeeze=False, wspace=0.2, hspace=0.2)

    for i, panel_columns in enumerate(panels):
        subfig = subfigs[i // ncols, i % ncols]
        # Stack one Axes per class; the negative hspace makes the ridges overlap.
        axes = subfig.subplots(
            num_of_classes, 1, squeeze=False, gridspec_kw={'hspace': -overlap}
        ).ravel()
        color = sns.color_palette(colormap, num_of_classes * len(panel_columns))
        title = ','.join(panel_columns)
        try:
            joyplot(
                data=data,
                by=by,
                column=panel_columns,
                ax=axes,
                alpha=alpha,
                title=title,
                color=color,
                linecolor=linecolor,
            )
        except Exception as err:
            # A tight_layout-related error is expected here and ignored.
            # NOTE(review): presumably joyplot calls tight_layout() on the
            # SubFigure that owns `axes`, which SubFigure does not support.
            # Confirm against the joypy version in use.
            if 'tight_layout' not in str(err):
                # Best-effort: report the failing panel and keep drawing the rest.
                print(f"ridge_plot: panel {title!r} failed: {err}")
    plt.show()
@twang2218
Copy link
Author

Demo (build a synthetic dataset with 9 classes and 30 features):

# Imports this snippet needs beyond the ones at the top of the gist.
import string

import pandas as pd

# Generate the dataset (fixed seed so the output is reproducible).
np.random.seed(42)
# Parameters
num_classes = 9  # must be <= 26: each label uses one uppercase letter
samples_per_class = 200
num_features = 30

# Build one feature matrix and one label array per class.
data = []
labels = []

for class_idx in range(num_classes):
    class_data = np.zeros((samples_per_class, num_features))
    for col in range(num_features):
        # Each (class, feature) pair gets its own mean, drawn uniformly from
        # [-1, 1); samples are normal with standard deviation 1 around it.
        random_mean = np.random.uniform(-1, 1)
        class_data[:, col] = np.random.normal(random_mean, 1, samples_per_class)
    label = f"Class_{string.ascii_uppercase[class_idx]}"
    class_labels = np.full(samples_per_class, label)
    data.append(class_data)
    labels.append(class_labels)

# Merge all classes into one matrix and one label vector.
data = np.vstack(data)
labels = np.concatenate(labels)

print(data.shape)  # (1800, 30)
print(labels.shape)  # (1800,)

df = pd.DataFrame(data, columns=[f'feature_{i}' for i in range(num_features)])
df['label'] = labels
# Preview; a bare expression is only displayed in a notebook or REPL.
df.sample(5)

Plot (one ridge panel for each of 9 randomly chosen features, with one ridge per class):

# Pick 9 distinct features at random. np.random.choice(n, ...) draws exactly
# as choice(range(n), ...) does, so the seeded selection is unchanged.
columns = [f'feature_{idx}' for idx in np.random.choice(num_features, 9, replace=False)]
ridge_plot(df, by='label', columns=columns, figsize=(12,12), colormap="viridis_r")

Output (the rendered figure is not reproduced in this copy):

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment