Created
December 4, 2023 14:39
-
-
Save danilotat/0b72125024d02b70841a154ac49c0902 to your computer and use it in GitHub Desktop.
Plot a boxplot comparing a quantitative (dependent) variable across three groups (defined by the independent variable), applying the most common statistical tests. P-values are shown either explicitly or as stars, as in GraphPad Prism.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from decimal import Decimal | |
from scipy import stats | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import matplotlib.font_manager as fm | |
import itertools | |
def _pvalue_to_stars(pval) -> str:
    """
    Convert a p-value into the significance label drawn above a bracket.

    Thresholds: '***' for p < 0.0001, '**' for p < 0.001, '*' for p < 0.05,
    'ns' otherwise. A NaN p-value yields 'ns' because every comparison with
    NaN is False.
    """
    if pval < 0.0001:
        return "***"
    if pval < 0.001:
        return "**"
    # Only an upper bound is needed here: smaller p-values were handled above.
    # This also covers p == 0.001 exactly, which must not be reported as 'ns'.
    if pval < 0.05:
        return "*"
    return "ns"


def plotBoxplotWithTest3Groups(df: pd.DataFrame, depVar: str, indVar: str, ax: plt.Axes, test: str, prop=None, explicitP=False) -> plt.Axes:
    """
    Plot a boxplot of three groups and annotate every pair with a statistical test.

    The test can be chosen between Kruskal-Wallis, Mann-Whitney or T-test and
    is run on each of the three pairs of groups. A bracket is drawn above each
    pair, labelled with the explicit p-value or with significance stars.

    Naming note: despite the parameter names, ``depVar`` is the *grouping*
    column (x axis, exactly 3 distinct values) and ``indVar`` is the
    *quantitative* column (y axis, the values being tested). The names are
    kept for backward compatibility.

    Parameters:
    - df: pd.DataFrame, one row per observation
    - depVar: str, grouping column; must have exactly 3 distinct values
    - indVar: str, quantitative column compared between the groups
    - ax: plt.Axes, an existing axes. This is made for better shape control
    - test: str, the test to perform (case-insensitive). Can be 'kruskal',
      'mann-whitney' or 'ttest'
    - prop: matplotlib.font_manager.FontProperties, optional. Font properties
      for the axis labels, tick labels and p-value annotations. When None the
      existing fonts are left unchanged.
    - explicitP: bool, whether to show the p-value explicitly or with stars
    Returns:
        plt.Axes, the same axes that was passed in
    Raises:
        ValueError: if ``depVar`` does not have exactly 3 distinct values, or
        if ``test`` is not one of the supported names.
    """
    vals = df[depVar].unique().tolist()
    if len(vals) != 3:
        raise ValueError("The dependent variable must have 3 different values")

    # Resolve the test before drawing anything, so an invalid name does not
    # leave a half-finished plot on the caller's axes.
    stat_tests = {
        "KRUSKAL": stats.kruskal,
        "MANN-WHITNEY": stats.mannwhitneyu,
        "TTEST": stats.ttest_ind,
    }
    statTest = stat_tests.get(test.upper())
    if statTest is None:
        raise ValueError("Test not recognized. Please choose between 'kruskal', 'mann-whitney' or 'ttest'")

    # order=vals pins box i to x position i, which the brackets below rely on
    # (they use vals.index(...)). Without it seaborn sorts numeric groups and
    # follows the category order of Categorical columns, so brackets could be
    # drawn over the wrong boxes.
    sns.boxplot(
        data=df, x=depVar, y=indVar, order=vals,
        palette={
            vals[0]: "#EF3B2C",
            vals[1]: "#20A39E",
            vals[2]: "#F2A900",
        },
        ax=ax, width=.5, showfliers=False)

    # Vertical offsets of brackets and labels are fractions of the top y limit.
    yrange = ax.get_ylim()[1]
    text_kwargs = {} if prop is None else {"fontproperties": prop}

    for value1, value2 in itertools.combinations(vals, 2):
        x1, x2 = vals.index(value1), vals.index(value2)
        group1 = df.loc[df[depVar] == value1, indVar]
        group2 = df.loc[df[depVar] == value2, indVar]

        # Bracket baseline: median + 1.5 * IQR of the two groups pooled.
        pooled = df.loc[df[depVar].isin([value1, value2]), indVar]
        iqr = pooled.quantile(.75) - pooled.quantile(.25)
        max_val = pooled.median() + iqr * 1.5
        # NOTE(review): each pair's height is computed independently, so the
        # brackets of different pairs may overlap.

        pval = statTest(group1, group2, nan_policy='omit')[1]

        low = max_val + yrange * 0.04
        high = max_val + yrange * 0.05
        ax.plot([x1, x1, x2, x2], [low, high, high, low], color='0.2')

        x_mid = (x1 + x2) / 2
        if explicitP:
            label = "p: {:.2e}".format(Decimal(pval))
            ax.text(
                x=x_mid, y=max_val + yrange * 0.08, s=label,
                ha="center", size=12, color="0.2", **text_kwargs)
        else:
            ax.text(
                x=x_mid, y=high, s=_pvalue_to_stars(pval),
                ha="center", size=14, weight="bold", color="0.2", **text_kwargs)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    for axis in ["bottom", "left"]:
        ax.spines[axis].set_linewidth(2)
        ax.spines[axis].set_color('0.2')

    # Passing None to set_fontproperties makes matplotlib fall back to a
    # default FontProperties, discarding the sizes seaborn/rcParams chose,
    # so only touch the fonts when the caller supplied some.
    if prop is not None:
        for item in [ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels():
            item.set_fontproperties(prop)
    return ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment