Snippet for reading in a table of numbers and predicting the last column as a function of the others, using a constant predictor, plain linear regression, or linear regression regularised with the elastic net. Uses NumPy and scikit-learn.
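The script loads its input with np.genfromtxt, so it expects a plain, whitespace-separated numeric table: one row per example, features in every column but the last, and the target value in the last column. As a purely illustrative sketch (the numbers below are made up), a three-feature input file could look like this:

    0.5  1.2  3.4   7.1
    0.9  0.3  2.2   4.8
    1.4  2.1  0.7   5.5
    2.0  0.8  1.9   6.3

Any header rows can be skipped via the skip_header argument to the loading functions.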
#!/usr/bin/env python

from __future__ import print_function

import numpy as np
from sklearn.linear_model import ElasticNet, LinearRegression
import sys

# James McDermott (c) 2013
# Hosted at https://gist.github.com/jmmcd/7790588
# Requires Numpy and Scikit-learn


def mae(y, yhat):
    """Calculate mean absolute error between inputs."""
    return np.mean(np.abs(y - yhat))


def rmse(y, yhat):
    """Calculate root mean square error between inputs."""
    return np.sqrt(np.mean(np.square(y - yhat)))


def get_Xy_train_test(filename, randomise=True, test_proportion=0.5, skip_header=0):
    """Read in a table of numbers and split it into X (all columns up
    to last) and y (last column), then split it into training and
    testing subsets according to test_proportion. Shuffle if
    required."""
    Xy = np.genfromtxt(filename, skip_header=skip_header)
    if randomise:
        np.random.shuffle(Xy)
    X = Xy[:,:-1]  # all columns but last
    y = Xy[:,-1]   # last column
    idx = int((1.0 - test_proportion) * len(y))
    train_X = X[:idx]
    train_y = y[:idx]
    test_X = X[idx:]
    test_y = y[idx:]
    return train_X, train_y, test_X, test_y


def get_Xy_train_test_separate(train_filename, test_filename, skip_header=0):
    """Read in training and testing data files, and split each into X
    (all columns up to last) and y (last column)."""
    train_Xy = np.genfromtxt(train_filename, skip_header=skip_header)
    test_Xy = np.genfromtxt(test_filename, skip_header=skip_header)
    train_X = train_Xy[:,:-1]  # all columns but last
    train_y = train_Xy[:,-1]   # last column
    test_X = test_Xy[:,:-1]    # all columns but last
    test_y = test_Xy[:,-1]     # last column
    return train_X, train_y, test_X, test_y


def fit_const(train_X, train_y, test_X, test_y):
    """Use the mean of the y training values as a predictor."""
    mn = np.mean(train_y)
    print("Predicting constant", mn)
    yhat = np.ones(len(train_y)) * mn
    print("Train error =", error(train_y, yhat))
    yhat = np.ones(len(test_y)) * mn
    print("Test error =", error(test_y, yhat))


def fit_lr(train_X, train_y, test_X, test_y):
    """Use linear regression to predict."""
    lr = LinearRegression()
    lr.fit(train_X, train_y)
    print("LR predicting intercept", lr.intercept_, "and coefs", lr.coef_)
    yhat = lr.predict(train_X)
    print("Train error =", error(train_y, yhat))
    yhat = lr.predict(test_X)
    print("Test error =", error(test_y, yhat))


def fit_enet(train_X, train_y, test_X, test_y):
    """Use linear regression to predict -- elastic net is LR with L1
    and L2 regularisation."""
    enet = ElasticNet()
    enet.fit(train_X, train_y)
    print("ElasticNet predicting intercept", enet.intercept_, "and coefs", enet.coef_)
    yhat = enet.predict(train_X)
    print("Train error =", error(train_y, yhat))
    yhat = enet.predict(test_X)
    print("Test error =", error(test_y, yhat))


if __name__ == "__main__":
    error = rmse
    #error = mae
    if len(sys.argv) == 3:
        train_filename = sys.argv[1]
        test_filename = sys.argv[2]
        train_X, train_y, test_X, test_y = get_Xy_train_test_separate(train_filename,
                                                                      test_filename)
    else:
        filename = sys.argv[1]
        train_X, train_y, test_X, test_y = get_Xy_train_test(filename)
    fit_const(train_X, train_y, test_X, test_y)
    fit_lr(train_X, train_y, test_X, test_y)
    fit_enet(train_X, train_y, test_X, test_y)
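A minimal usage sketch, assuming the snippet has been saved as, say, fit.py (the script name and data filenames here are hypothetical):

    # single file: shuffled, then split 50/50 into training and testing subsets
    python fit.py data.txt

    # separate training and testing files
    python fit.py train.txt test.txt

Each of fit_const, fit_lr and fit_enet prints its fitted model followed by its training and test error under the metric chosen in the __main__ block (rmse by default, with mae available as the commented-out alternative).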