Skip to content

Instantly share code, notes, and snippets.

@smzn
Created January 22, 2024 06:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save smzn/1c731107952d07683549822e2fc255cf to your computer and use it in GitHub Desktop.
Save smzn/1c731107952d07683549822e2fc255cf to your computer and use it in GitHub Desktop.
2値データ以外での散布図 (Scatter plots for non-binary numeric data)
# Import the required libraries. NOTE: the DataFrame `data` is not loaded here; it must already exist (e.g. from an earlier notebook cell).
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import itertools
def is_binary(column):
    """Return True if *column* holds exactly the two values 0 and 1.

    Missing values (NaN/None) are ignored, so a 0/1 indicator column with
    gaps is still recognised as binary. A column containing only one of the
    two values (e.g. all zeros), or no values at all, is not binary.

    Args:
        column: a pandas Series (one column of the DataFrame).

    Returns:
        bool: whether the non-missing values are exactly {0, 1}.
    """
    # Drop missing values first: otherwise NaN shows up in unique() and makes
    # every indicator column with gaps look non-binary.
    return sorted(column.dropna().unique()) == [0, 1]
# Identify the numeric columns in the dataset.
# NOTE(review): `data` is not defined in this snippet; it is assumed to be a
# pandas DataFrame created earlier (e.g. in a previous notebook cell) — confirm.
numeric_columns = data.select_dtypes(include=['float64', 'int64']).columns
# Drop 0/1 indicator columns so only non-binary numeric columns are plotted.
non_binary_numeric_columns = [col for col in numeric_columns if not is_binary(data[col])]
# One scatter plot per unordered pair of non-binary numeric columns.
non_binary_combinations = list(itertools.combinations(non_binary_numeric_columns, 2))

# Lay the plots out on a grid with a fixed number of plots per row.
n_plots = len(non_binary_combinations)
n_cols = 3  # Number of plots per row
n_rows = (n_plots + n_cols - 1) // n_cols  # Ceiling division: rows needed for n_plots

if n_plots == 0:
    # plt.subplots(0, n_cols) raises ValueError, so report instead of crashing.
    print('Not enough non-binary numeric columns to draw scatter plots.')
else:
    # squeeze=False always returns a 2-D array of Axes, so flatten() is safe.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    # Draw one scatter plot per column pair, filling the grid row by row.
    for ax, (col1, col2) in zip(axes, non_binary_combinations):
        sns.scatterplot(x=col1, y=col2, data=data, ax=ax, alpha=0.5)
        ax.set_title(f'{col1} vs {col2}')

    # Hide the unused grid cells after the last plot (indexed by n_plots,
    # not by a loop variable leaked from the loop above).
    for ax in axes[n_plots:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment