Skip to content

Instantly share code, notes, and snippets.

@dgk
Created December 22, 2011 18:21
Show Gist options
  • Save dgk/1511294 to your computer and use it in GitHub Desktop.
Save dgk/1511294 to your computer and use it in GitHub Desktop.
PyBrain implementation of the chapter 5 sample code from M. Tim Jones's book "AI Application Programming"
# -*- coding: utf-8 -*-
from pybrain.tools.shortcuts import buildNetwork
from pybrain.datasets.supervised import SupervisedDataSet
from pybrain.supervised.trainers.backprop import BackpropTrainer
from pybrain.structure.modules.softmax import SoftmaxLayer
from pybrain.structure.modules.tanhlayer import TanhLayer
# Behaviours the network can choose, in output-neuron order: the index of the
# most active output neuron is used to pick one of these names.
actions = (
    "Attack",
    "Run",
    "Wander",
    "Hide",
)
# Input features, in input-vector order; also used to label each test row
# when the results are printed.
inputs = (
    'Health',
    'Knife',
    'Gun',
    'Enemy',
)
# Training set: 18 (input, target) pairs.  Each input is
# (Health, Knife, Gun, Enemy); each target is a one-hot vector over
# (Attack, Run, Wander, Hide).  The one-hot vector is generated from the
# index of the desired action so the intended behaviour can be read off
# each row; the resulting tuples are identical to spelling them out.
samples = tuple(
    (features, tuple(1.0 if slot == wanted else 0.0 for slot in range(4)))
    for features, wanted in (
        ((2.0, 0.0, 0.0, 0.0), 2),  # Wander
        ((2.0, 0.0, 0.0, 1.0), 2),  # Wander
        ((2.0, 0.0, 1.0, 1.0), 0),  # Attack
        ((2.0, 0.0, 1.0, 2.0), 0),  # Attack
        ((2.0, 1.0, 0.0, 2.0), 3),  # Hide
        ((2.0, 1.0, 0.0, 1.0), 0),  # Attack
        ((1.0, 0.0, 0.0, 0.0), 2),  # Wander
        ((1.0, 0.0, 0.0, 1.0), 3),  # Hide
        ((1.0, 0.0, 1.0, 1.0), 0),  # Attack
        ((1.0, 0.0, 1.0, 2.0), 3),  # Hide
        ((1.0, 1.0, 0.0, 2.0), 3),  # Hide
        ((1.0, 1.0, 0.0, 1.0), 3),  # Hide
        ((0.0, 0.0, 0.0, 0.0), 2),  # Wander
        ((0.0, 0.0, 0.0, 1.0), 3),  # Hide
        ((0.0, 0.0, 1.0, 1.0), 3),  # Hide
        ((0.0, 0.0, 1.0, 2.0), 1),  # Run
        ((0.0, 1.0, 0.0, 2.0), 1),  # Run
        ((0.0, 1.0, 0.0, 1.0), 3),  # Hide
    )
)
# Situations used to query the trained network, one row per query, with
# values in (Health, Knife, Gun, Enemy) order.  The trailing comment on each
# row is the action the original script listed for that input (presumably
# the book's reference output -- confirm against the text).  Rows with
# Enemy = 3 lie outside the 0..2 range present in the training samples, so
# they exercise extrapolation.
#
# These expectations used to live in a bare triple-quoted string, which is
# evaluated and discarded at runtime; real comments keep them next to the
# rows they describe.
test_data = (
    (2.0, 1.0, 1.0, 1.0),  # Wander
    (1.0, 1.0, 1.0, 2.0),  # Hide
    (0.0, 0.0, 0.0, 0.0),  # Wander
    (0.0, 1.0, 1.0, 1.0),  # Hide
    (2.0, 0.0, 1.0, 3.0),  # Hide
    (2.0, 1.0, 0.0, 3.0),  # Hide
    (0.0, 1.0, 0.0, 3.0),  # Run
)
# Build a 4-3-4 feed-forward network with bias units: one input per feature
# in `inputs`, 3 hidden neurons, one output per entry in `actions`.
# Layer types are buildNetwork's defaults; the TanhLayer and SoftmaxLayer
# imported at the top of the file are not passed in.
# NOTE(review): confirm whether they were meant as hiddenclass/outclass.
net = buildNetwork(4, 3, 4, bias=True)

# Supervised dataset with 4-wide inputs and 4-wide targets, filled with
# every (input, target) pair from the training table.
ds = SupervisedDataSet(4, 4)
for features, target in samples:
    ds.addSample(features, target)

# Plain backpropagation: learning rate 0.1 and no momentum.
trainer = BackpropTrainer(net, ds, learningrate=0.1, momentum=0.)
# Train in 100 rounds of trainEpochs(5).
for _ in range(100):
    trainer.trainEpochs(5)
# Query the trained network with each test situation and print the winning
# action after a readable description of the inputs, e.g.
# "Health = 2 Knife = 1 Gun = 1 Enemy = 1 Wander".
for row in test_data:
    description = ' '.join(
        '%s = %s' % (name, int(value)) for name, value in zip(inputs, row)
    )
    scores = list(net.activate(row))
    # The action whose output neuron is most active wins (argmax).
    best = scores.index(max(scores))
    # A single print of one preformatted string produces the same line on
    # Python 2 (print statement) and Python 3 (print function); the old
    # "print a," / "print b" pair relied on Python 2's softspace.
    print('%s %s' % (description, actions[best]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment