Created
November 18, 2016 06:00
-
-
Save alanbernstein/d34ced4786d24bdc20bc6b923ba33308 to your computer and use it in GitHub Desktop.
Fit some curves with quadratic RANSAC.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/local/bin/python | |
import csv | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from sklearn import linear_model | |
from sklearn.preprocessing import PolynomialFeatures | |
def get_data(fname, x_col='sintl2', y_col='I1/I2%'):
    """Read two numeric columns from a CSV file that has a header row.

    Args:
        fname: path of the CSV file.
        x_col: header name of the column returned as ``x``
            (default ``'sintl2'``).
        y_col: header name of the column returned as ``y``
            (default ``'I1/I2%'``).

    Returns:
        ``(x, y)``: two 1-D float numpy arrays with one entry per data row
        (both empty if the file has only a header).

    Raises:
        KeyError: if ``x_col`` or ``y_col`` is not in the header.
        ValueError: if a cell in either column is not a valid float
            (TypeError if a row is shorter than the header).
    """
    # The csv module requires newline='' so that quoted fields containing
    # line breaks (and \r\n terminators) are parsed correctly.
    with open(fname, newline='') as f:
        rows = list(csv.DictReader(f))
    x = np.array([float(r[x_col]) for r in rows])
    y = np.array([float(r[y_col]) for r in rows])
    return x, y
def quadratic_ransac_fit(x, y, name, random_state=None):
    """Robustly fit y = c0 + c1*x + c2*x^2 with RANSAC and plot the result.

    Draws on the current matplotlib axes: RANSAC inliers as black dots,
    outliers as red dots, and the fitted curve over [min(x), max(x)].
    The axes title shows ``name`` and the fitted coefficients.

    Args:
        x: 1-D sequence of independent values.
        y: 1-D sequence of dependent values, same length as ``x``.
        name: prefix for the plot title.
        random_state: optional seed / RandomState forwarded to
            ``RANSACRegressor``; ``None`` (default) keeps RANSAC's
            non-deterministic sampling.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_col = x.reshape((-1, 1))
    xi = np.linspace(x.min(), x.max(), 100).reshape((-1, 1))

    # PolynomialFeatures(degree=2) emits the columns [1, x, x^2]; fit it on
    # the data and reuse the same transform for the evaluation grid.
    poly_2 = PolynomialFeatures(degree=2)
    x_2 = poly_2.fit_transform(x_col)
    xi_2 = poly_2.transform(xi)

    m = linear_model.RANSACRegressor(linear_model.LinearRegression(),
                                     random_state=random_state)
    m.fit(x_2, y)
    yi = m.predict(xi_2)

    # LinearRegression fits its own intercept, so the x^0 term lives in
    # intercept_, not in coef_.  The coef_ entry for the constant bias
    # column is (near) zero, but add it anyway so that
    # c0 + c1*x + c2*x^2 reproduces m.predict exactly.
    est = m.estimator_
    coef = np.ravel(est.coef_)
    intercept = float(np.ravel(est.intercept_)[0])
    c_b = np.array([intercept + coef[0], coef[1], coef[2]])

    inlier_mask = m.inlier_mask_
    outlier_mask = np.logical_not(inlier_mask)
    plt.plot(x[inlier_mask], y[inlier_mask], 'k.', label='data')
    plt.plot(x[outlier_mask], y[outlier_mask], 'r.', label='data (outliers)')
    plt.plot(xi, yi, label='quadratic ransac')
    plt.title('%s: %0.5f + %0.5fx + %0.5fx^2' % (name, c_b[0], c_b[1], c_b[2]))
    plt.legend()
if __name__ == '__main__':
    # Fit and plot each dataset ('<name>.csv') in its own vertically
    # stacked subplot, then show the whole figure once.
    names = ['clq', 'uf6']
    for n, name in enumerate(names, start=1):
        x, y = get_data(name + '.csv')
        # Size the grid from the list so adding a dataset doesn't overflow it.
        plt.subplot(len(names), 1, n)
        quadratic_ransac_fit(x, y, name)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment