Gist: Ailuropoda1864/fad39099c62c6ab9c2efc606f5db554a
Last active September 8, 2017 06:08
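A function that draws a scatter plot of two variables, overlays a straight (linearly fitted) line, and returns the Pearson correlation coefficient of the two variables with its p-value.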
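When `slope` or `y_intercept` is omitted, the function fits the line with `np.polyfit(x, y, 1)`, which is a least-squares fit of a degree-1 polynomial. The standard closed forms below are not part of the gist. They show how that fitted slope $\hat{m}$ and intercept $\hat{b}$ relate to the Pearson $r$ the function returns, where $s_x$ and $s_y$ are the sample standard deviations of `x` and `y`:

```latex
\hat{m} = \frac{\sum_{i}(x_i - \bar{x})(y_i - \bar{y})}{\sum_{i}(x_i - \bar{x})^2},
\qquad
\hat{b} = \bar{y} - \hat{m}\,\bar{x},
\qquad
r = \frac{\sum_{i}(x_i - \bar{x})(y_i - \bar{y})}{\sqrt{\sum_{i}(x_i - \bar{x})^2}\,\sqrt{\sum_{i}(y_i - \bar{y})^2}}
\quad\Longrightarrow\quad
\hat{m} = r\,\frac{s_y}{s_x}
```

So whenever the line is fitted rather than supplied, its slope has the same sign as the returned $r$.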
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats


def scatter_plot_with_linear_fit(x, y, slope=None, y_intercept=None):
    """
    Draw a scatter plot of y against x with a dashed straight line on top.

    :param x: an array of x values
    :param y: an array of y values, the same length as x
    :param slope: slope of the line to draw
    :param y_intercept: y-intercept of the line to draw
    If either slope or y_intercept is not specified, both are estimated
    by a least-squares linear fit (np.polyfit with degree 1).
    :return: Pearson correlation coefficient and two-sided p-value of x and y
    """
    # raw data points, semi-transparent so overlapping points stay visible
    plt.scatter(x, y, alpha=0.8)
    # fitted line
    if slope is None or y_intercept is None:
        # polyfit returns coefficients highest degree first: [slope, intercept]
        slope, y_intercept = np.polyfit(x, y, 1)
    # evaluate the line at 100 evenly spaced points spanning the range of x
    x_fit = np.linspace(np.min(x), np.max(x), 100)
    y_fit = slope * x_fit + y_intercept
    plt.plot(x_fit, y_fit, linestyle='dashed', color='black', alpha=0.5)
    # correlation is computed from the data, independent of the drawn line
    return stats.pearsonr(x, y)
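A minimal usage sketch, assuming synthetic data generated with NumPy. The data, the random seed, and the non-interactive `Agg` backend are illustrative assumptions, not from the gist, and the function is repeated in condensed form so the block runs on its own. Because the function draws onto matplotlib's current axes, creating a figure with `plt.subplots()` first lets you label that figure afterwards; the returned result unpacks into `r` and `p`.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats


def scatter_plot_with_linear_fit(x, y, slope=None, y_intercept=None):
    # condensed copy of the gist's function so this example is self-contained
    plt.scatter(x, y, alpha=0.8)
    if slope is None or y_intercept is None:
        slope, y_intercept = np.polyfit(x, y, 1)
    x_fit = np.linspace(np.min(x), np.max(x), 100)
    y_fit = slope * x_fit + y_intercept
    plt.plot(x_fit, y_fit, linestyle='dashed', color='black', alpha=0.5)
    return stats.pearsonr(x, y)


# illustrative synthetic data: y = 2x + 1 plus Gaussian noise
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=50)
y = 2 * x + 1 + rng.normal(scale=1.0, size=50)

# plt.subplots() makes this figure current, so the function draws into ax
fig, ax = plt.subplots()
r, p = scatter_plot_with_linear_fit(x, y)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title(f"Pearson r = {r:.3f}, p = {p:.2g}")
# display with plt.show() (interactive backend) or save with fig.savefig(...)
```

The function never calls `plt.show()` or saves anything, so the caller decides how to display or save the figure. Passing both `slope` and `y_intercept` draws that line instead of the fitted one, while the returned `r` and `p` are still computed from the data.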