Created September 24, 2013 21:20
Logistic Regression with Stochastic Gradient Descent.
Copyright (c) 2009, Naoaki Okazaki
This code illustrates an implementation of logistic regression models
trained by Stochastic Gradient Descent (SGD).
This program reads a training set from STDIN, trains a logistic regression
model, evaluates the model on a test set (given by the first argument) if
specified, and outputs the feature weights to STDOUT. This is the typical
usage of this program:
$ ./ test.txt < train.txt
Each line in a data set represents one instance: a label followed by one or
more binary features, separated by TAB characters. This is the BNF notation
of the data format:
<line> ::= <label> ('\t' <feature>)+ '\n'
<label> ::= '1' | '0'
<feature> ::= <string>
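As a minimal sketch of how one line in this format can be split into a label and its features (the helper name parse_line and the sample feature strings are hypothetical, not part of the original program):

```python
def parse_line(line):
    """Split one TAB-separated line into (label, features)."""
    fields = line.rstrip('\n').split('\t')
    # The first field is the label; int() raises ValueError on non-numeric input.
    label = int(fields[0])
    if label not in (0, 1):
        raise ValueError('label must be 0 or 1, got %r' % fields[0])
    # The grammar requires at least one feature after the label.
    features = fields[1:]
    if not features:
        raise ValueError('at least one feature is required')
    return label, features


print(parse_line('1\tfree\tmoney\tnow\n'))  # (1, ['free', 'money', 'now'])
```

The program itself does no such validation; it simply splits on TAB and converts the first field to a number.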
The following topics are not covered for simplicity:
- bias term
- regularization
- real-valued features
- multiclass logistic regression (maximum entropy model)
- two or more iterations for training
- calibration of learning rate
This code requires Python 2.5 or later for collections.defaultdict().
import collections
import math
import sys

N = 17997   # Change this to the number of training instances.
eta0 = 0.1  # Initial learning rate; change this if desired.

def update(W, X, l, eta):
    # Compute the inner product of features and their weights.
    a = sum([W[x] for x in X])
    # Compute the gradient of the error function (avoiding +Inf overflow).
    g = ((1. / (1. + math.exp(-a))) - l) if -100. < a else (0. - l)
    # Update the feature weights by Stochastic Gradient Descent.
    for x in X:
        W[x] -= eta * g

def train(fi):
    t = 1
    W = collections.defaultdict(float)
    # Loop for instances.
    for line in fi:
        fields = line.strip('\n').split('\t')
        update(W, fields[1:], float(fields[0]), eta0 / (1 + t / float(N)))
        t += 1
    return W

def classify(W, X):
    return 1 if 0. < sum([W[x] for x in X]) else 0

def test(W, fi):
    m = 0
    n = 0
    for line in fi:
        fields = line.strip('\n').split('\t')
        l = classify(W, fields[1:])
        m += (1 - (l ^ int(fields[0])))
        n += 1
    print('Accuracy = %f (%d/%d)' % (m / float(n), m, n))

if __name__ == '__main__':
    W = train(sys.stdin)
    if 1 < len(sys.argv):
        test(W, open(sys.argv[1]))
    # items() works on both Python 2 and Python 3 (iteritems() is Python 2 only).
    for name, value in W.items():
        print('%f\t%s' % (value, name))
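
To make the update rule and the learning-rate decay concrete, here is a small self-contained sketch of a single SGD step on a toy instance. The helper names sigmoid, sgd_step and learning_rate, the feature strings, and the training-set size of 4 are all assumptions made for illustration; the logic mirrors update() and the eta0 / (1 + t / N) schedule used in train().

```python
import collections
import math


def sigmoid(a):
    # Same overflow guard as update(): treat the output as 0 for very negative a.
    return 1.0 / (1.0 + math.exp(-a)) if a > -100.0 else 0.0


def sgd_step(W, X, label, eta):
    # Inner product of binary features and their weights.
    a = sum(W[x] for x in X)
    # Gradient of the log loss with respect to the activation a.
    g = sigmoid(a) - label
    # Every active feature receives the same correction.
    for x in X:
        W[x] -= eta * g
    return g


def learning_rate(eta0, t, n):
    # Decays from roughly eta0 at t = 1 towards eta0 / 2 at t = n.
    return eta0 / (1.0 + t / float(n))


W = collections.defaultdict(float)
eta = learning_rate(0.1, 1, 4)   # first instance of a hypothetical 4-instance set
g = sgd_step(W, ['free', 'money'], 1, eta)
print(g)        # -0.5, because all weights start at zero and sigmoid(0) = 0.5
print(dict(W))  # both features are nudged upward by eta * 0.5
```

With all weights at zero the model predicts 0.5, so a positive instance yields a gradient of -0.5 and each active feature weight increases; a negative instance would decrease them by the same amount.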