"""
Logistic Regression with Stochastic Gradient Descent.
Copyright (c) 2009, Naoaki Okazaki
http://www.chokkan.org/publication/survey/logistic_regression_sgd.html
This code illustrates an implementation of logistic regression models
trained by Stochastic Gradient Descent (SGD).
This program reads a training set from STDIN and trains a logistic regression
model. If a test set is given as the first argument, the model is evaluated on
it; otherwise the learned feature weights are written to STDOUT. This is the
typical usage of this program:
$ ./logistic_regression_sgd.py test.txt < train.txt
Each line in a data set represents an instance, consisting of a label and
binary features separated by TAB characters. This is the BNF notation of the
data format:
<line> ::= <label> ('\t' <feature>)+ '\n'
<label> ::= '1' | '0'
<feature> ::= <string>
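For example, a hypothetical training instance with two made-up binary feature
names would be the single line:
    1\tword:free\tsubject_all_caps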
The following topics are not covered for simplicity:
- bias term
- regularization
- real-valued features
- multiclass logistic regression (maximum entropy model)
- two or more iterations for training
- calibration of learning rate
This code requires Python 2.5 or later for collections.defaultdict(); it also
runs unmodified under Python 3.
"""
import collections
import math
import sys
N = 17997 # Change this to the number of instances in the training set.
eta0 = 0.1 # Initial learning rate; change this if desired.
def update(W, X, l, eta):
    # Compute the inner product of features and their weights.
    a = sum([W[x] for x in X])
    # Compute the gradient of the error function (avoiding +Inf overflow).
    g = ((1. / (1. + math.exp(-a))) - l) if -100. < a else (0. - l)
    # Update the feature weights by Stochastic Gradient Descent.
    for x in X:
        W[x] -= eta * g
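# A quick sketch of the math behind update(), for reference: with binary
# features the score is a = sum(W[x] for x in X), the model probability is
# p = 1 / (1 + exp(-a)), and the per-instance log-loss is
# -(l * log(p) + (1 - l) * log(1 - p)). Its derivative with respect to each
# active weight W[x] is (p - l), which is exactly g above, so W[x] -= eta * g
# performs one stochastic gradient step.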
def train(fi):
    t = 1
    W = collections.defaultdict(float)
    # Loop over the training instances, one SGD update per line.
    for line in fi:
        fields = line.strip('\n').split('\t')
        update(W, fields[1:], float(fields[0]), eta0 / (1 + t / float(N)))
        t += 1
    return W
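# Note on the call above: the learning rate passed to update() decays as
# eta0 / (1 + t / N), so it starts near eta0 and falls to roughly eta0 / 2
# by the end of the single pass over the N training instances.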
def classify(W, X):
    # Predict 1 when the score is positive, i.e. the estimated P(y=1|x) exceeds 0.5.
    return 1 if 0. < sum([W[x] for x in X]) else 0
def test(W, fi):
    m = 0
    n = 0
    for line in fi:
        fields = line.strip('\n').split('\t')
        l = classify(W, fields[1:])
        # l ^ gold is 0 when the prediction matches the gold label, so m counts correct predictions.
        m += (1 - (l ^ int(fields[0])))
        n += 1
    print('Accuracy = %f (%d/%d)' % (m / float(n), m, n))
if __name__ == '__main__':
    W = train(sys.stdin)
    if 1 < len(sys.argv):
        test(W, open(sys.argv[1]))
    else:
        for name, value in W.items():
            print('%f\t%s' % (value, name))
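# Example invocations, assuming the script is saved as logistic_regression_sgd.py
# and that train.txt / test.txt follow the TAB-separated format described above:
#
#     python logistic_regression_sgd.py test.txt < train.txt        # train, then report accuracy
#     python logistic_regression_sgd.py < train.txt > weights.txt   # train, dump feature weights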