Skip to content

Instantly share code, notes, and snippets.

@davidandrzej
Created May 3, 2011 18:32
Show Gist options
  • Save davidandrzej/953913 to your computer and use it in GitHub Desktop.
Augment 2D scatter plot with linear regression
"""
Augment scatter plot with linear regression fit
David Andrzejewski
"""
import numpy as NP
import numpy.random as NPR
import matplotlib.pyplot as P
import matplotlib.lines as L
from scikits.learn.linear_model import LinearRegression
def linearscatter(xpts, ypts, ax=None, **kwargs):
    """
    Draw a scatter plot of (xpts, ypts) and overlay its least-squares line.

    Parameters
    ----------
    xpts, ypts : array_like
        Point coordinates, one value per point, same length. Plain
        sequences are accepted and converted with numpy.asarray.
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a new figure is created and its
        current Axes is used.
    **kwargs
        Forwarded unchanged to ax.scatter() (e.g. marker, s, c).

    Returns
    -------
    (ax, model) : tuple
        The Axes that was drawn on and the fitted LinearRegression model.
    """
    # Accept lists/tuples as well as arrays: .min()/.max() below need ndarrays.
    xpts = NP.asarray(xpts)
    ypts = NP.asarray(ypts)
    if ax is None:
        ax = P.figure().gca()
    # Scatter directly on the requested Axes. pyplot.scatter() always draws
    # on the *current* Axes, so passing axes=ax through it does not select
    # where the points go.
    ax.scatter(xpts, ypts, **kwargs)
    # Ordinary least squares fit; the estimator takes a 2-D design matrix
    # with one row per point and a single feature column.
    model = LinearRegression()
    model.fit(NP.reshape(xpts, (len(xpts), 1)), ypts)
    # The fit is a straight line, so its endpoints over the x range suffice.
    xmin, xmax = xpts.min(), xpts.max()
    slope = model.coef_[0]
    intercept = model.intercept_
    ax.add_line(L.Line2D([xmin, xmax],
                         [xmin * slope + intercept, xmax * slope + intercept],
                         color='r', linewidth=5))
    return (ax, model)
if __name__ == '__main__':
    # Synthetic dataset: x drawn uniformly from [1, 10), y = coeff * x plus
    # standard-normal noise.
    npts = 200
    coeff = 1.0
    xpts = NPR.uniform(low=1.0, high=10.0, size=(npts,))
    noise = NPR.standard_normal(size=xpts.shape)
    ypts = xpts * coeff + noise

    # Scatter the points with the fitted regression line drawn on top.
    ax, model = linearscatter(xpts, ypts)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')

    # Report the fit's R^2 on the same points, using a one-column design matrix.
    design = xpts.reshape((-1, 1))
    r2 = model.score(design, ypts)
    ax.set_title('Linear regression (R2 = %.2f)' % (r2,))
    P.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment