Skip to content

Instantly share code, notes, and snippets.

@mryssng
Last active January 23, 2021 16:31
Show Gist options
  • Save mryssng/93b3c7ed431d00ee91930d43a5aa0535 to your computer and use it in GitHub Desktop.
Save mryssng/93b3c7ed431d00ee91930d43a5aa0535 to your computer and use it in GitHub Desktop.
Scatter plot matrix with pandas and seaborn (pairplot)
# coding: utf-8
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
def add_range_column(df, column='a', n_bins=10):
    """Bin ``df[column]`` into *n_bins* equal-width intervals, in place.

    The result is stored as a categorical column named ``<column>_range``
    (``'a_range'`` for the default column), which is used as the seaborn
    ``hue`` when drawing the scatter plot matrix.

    Args:
        df: DataFrame containing a numeric *column*. Modified in place.
        column: Name of the column to bin.
        n_bins: Number of equal-width intervals spanning the column's
            minimum to maximum.

    Returns:
        The same *df* object, for convenience.

    Raises:
        ValueError: Raised by ``pd.cut`` when the column is constant
            (min == max), because the bin edges would not be distinct.
    """
    values = df[column]
    # n_bins intervals need n_bins + 1 edges.
    edges = np.linspace(values.min(), values.max(), n_bins + 1)
    # pd.cut builds right-closed intervals (lo, hi]; without include_lowest
    # the row(s) holding the minimum fall outside every bin (NaN) and are
    # silently dropped from the hue-coloured plot.
    df[column + '_range'] = pd.cut(values, bins=edges, include_lowest=True)
    return df


def place_legend_top_left(g, ncol=3):
    """Replace a seaborn grid's default legend with a figure legend at the upper left.

    Args:
        g: The grid returned by ``sns.pairplot``.
        ncol: Number of legend columns.
    """
    # _legend_data and _legend are private seaborn attributes; newer seaborn
    # releases offer sns.move_legend as a public alternative.
    legend_data = g._legend_data
    handles = list(legend_data.values())
    labels = list(legend_data.keys())
    g._legend.remove()
    g.fig.legend(handles=handles, labels=labels, loc='upper left', ncol=ncol)


def main(dataset_file='dataset.csv', output_file='scatter_plot_matrix.png',
         column='a', n_bins=10):
    """Draw a scatter plot matrix of *dataset_file*, coloured by binned *column*.

    The figure is saved to *output_file* and then shown interactively.
    """
    plt.close('all')

    # Load dataset and add the categorical range column used as hue.
    df = pd.read_csv(dataset_file)
    add_range_column(df, column, n_bins)
    range_column = column + '_range'

    # sns.set() resets the style to its default, so the white style must be
    # requested in the same call rather than via set_style() beforehand.
    sns.set(style='white')
    g = sns.pairplot(df, hue=range_column,
                     hue_order=df[range_column].cat.categories,
                     palette='YlGnBu', aspect=1.5)

    # Change legend location.
    place_legend_top_left(g)
    g.savefig(output_file)
    plt.show()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment