Skip to content

Instantly share code, notes, and snippets.

@jeanbaptisteb
Last active June 16, 2023 07:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jeanbaptisteb/b89c5d0d0743b40ddd4d1a5f43412964 to your computer and use it in GitHub Desktop.
Save jeanbaptisteb/b89c5d0d0743b40ddd4d1a5f43412964 to your computer and use it in GitHub Desktop.
function to plot static representation of ACM with the output from the PRINCE library
from adjustText import adjust_text
import matplotlib.pyplot as plt
import seaborn as sns
def plot_acm(model, data, x=0, y=1, dims = (11, 8),
keep_prefix=True,
adjust_labels=True):
'''
Parameters
----------
model : prince.mca.MCA
en: MCA model (fitted) from the PRINCE library.
fr: Modèle d'ACM (déjà fitté) de la librairie PRINCE.
data : pandas.core.frame.DataFrame
en: Pandas dataframe containing data.
fr: DataFrame pandas contenant les données.
x : int, optionel
en: Identifier of the component to plot horizontally. By default, 0.
fr: Identifiant de l'axe factoriel qu'on souhaite représenter horizontalement. 0 par défaut.
y : int, optionel
en: Identifier of the component to plot vertically. By default, 1.
fr: Identifiant de l'axe factoriel qu'on souhaite représenter verticalement. 1 par défaut.
dims : tuple, optional
en: Plot size. (11, 8) by default.
fr: Taille du graphique. (11, 8) par défaut.
keep_prefix : boolean, optional
en: To prepend (or not) the name of the variable to the value (e.g. if set to True, will display something like "COLOR_purple",
otherwise will display "purple").
fr: Indique si le texte affiché à côté de chaque point projeté contient un préfixe contenant le nom de la variable.
Si défini à "False", seul l'intitulé de la modalité apparait.
adjust_text : boolean, optional
en: Preventing text from overlapping if set to True.
fr: Si défini à "True", réorganise automatiquement les labels de chaque point de manière à ce qu'ils ne se chevauchent pas.
Returns
-------
plot : matplotlib.axes._subplots.AxesSubplot
Graphique matplotlib.
'''
variables = data.columns
# variance = model.eigenvalues_summary
variance = model.percentage_of_variance_
coord = model.column_coordinates(data)
#coord contient les coordonnées de chaque modalité sur chaque axe factoriel
#on ajoute une colonne "variable", qui permet de faire le lien entre la modalité et la variable à laquelle à correspond
coord["variable"] = ""
new_index = []
for i in range(0, len(coord)):
row = coord.iloc[i]
var_name = row.name
for variable in variables:
values = data[variable].dropna().unique()
for value in values:
comb= "_".join([variable, value])
if comb == var_name:
coord.iloc[i, coord.columns.get_loc("variable")] = variable
new_index.append(coord.iloc[i].name.replace(variable+"_", ""))
if keep_prefix == False:
coord.index = new_index
#On représente ensuite les modalités sous forme d'un nuage de points
sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=dims)
plot = sns.scatterplot(x=coord[x],
y=coord[y],
hue=coord["variable"],
ax=ax,
)
#on modifie l'intitulé des axex
ax.set_xlabel(f"Dimension {str(x)} ({round(variance[x],1)} %)" )
ax.set_ylabel(f"Dimension {str(y)} ({round(variance[y],1)} %)" )
#on ajoute des lignes pointillées aux origines, pour faciliter l'analyse
plot.axhline(y = 0, color = 'black', linestyle = '--', linewidth = 0.5)
plot.axvline(x = 0, color = 'black', linestyle = '--', linewidth = 0.5)
#on ajoute l'intitulé des points
txts = []
for line in range(0, coord.shape[0]):
txt = plot.text(coord[x][line]+0.03,
coord[y][line],
coord.index[line],
)
txts.append(txt)
if adjust_labels == True:
adjust_text(txts)
#derniers réglages :
#*on supprime l'encadré autour du graphique
#on déplace la légende à l'extérieur du graphique
sns.despine()
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
return plot
import pandas
import prince
dataset = pandas.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/balloons/adult+stretch.data')
dataset.columns = ['Color', 'Size', 'Action', 'Age', 'Inflated']
mca = prince.MCA(
n_components=20,
n_iter=100,
copy=True,
check_input=True,
engine='sklearn',
random_state=42
)
mca = mca.fit(dataset)
plot_acm(model=mca,
data=dataset,
x=0, y=1, dims = (11, 10),
adjust_labels=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment