Skip to content

Instantly share code, notes, and snippets.

@szolotykh
Created October 16, 2014 22:58
Show Gist options
  • Save szolotykh/2d14c5be03cf70e15390 to your computer and use it in GitHub Desktop.
Save szolotykh/2d14c5be03cf70e15390 to your computer and use it in GitHub Desktop.
Neural network test: Integer classification
from pybrain.tools.shortcuts import buildNetwork
from pybrain.structure import FeedForwardNetwork
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer
from pybrain.structure import LinearLayer, SigmoidLayer, TanhLayer
from pybrain.structure import FullConnection
import random
def int2bin(a, width=4):
    """Return the bits of ``a`` as a list of ints, most significant bit first.

    The list always has exactly ``width`` entries (left-padded with zeros), so
    the default of 4 matches the 4-unit input layer and the
    ``SupervisedDataSet(4, 1)`` used by this script.

    Raises ValueError if ``a`` is negative or needs more than ``width`` bits,
    rather than returning a vector of the wrong length.

    >>> int2bin(5)
    [0, 1, 0, 1]
    """
    if a < 0 or a >= 2 ** width:
        raise ValueError("int2bin: %r does not fit in %d unsigned bits" % (a, width))
    # Walk the bit positions from the highest (width - 1) down to 0.
    return [(a >> shift) & 1 for shift in reversed(range(width))]
# Network: 4 linear inputs -> sigmoid hidden layers of 8, 16, 32 and 8 units
# -> 1 linear output. No bias module is added to this network.
net = FeedForwardNetwork()

# Layers, created in input-to-output order.
inLayer = LinearLayer(4)
hiddenLayer1, hiddenLayer2, hiddenLayer3, hiddenLayer4 = [
    SigmoidLayer(units) for units in (8, 16, 32, 8)
]
outLayer = LinearLayer(1)
hiddenLayers = (hiddenLayer1, hiddenLayer2, hiddenLayer3, hiddenLayer4)

net.addInputModule(inLayer)
for layer in hiddenLayers:
    net.addModule(layer)
net.addOutputModule(outLayer)

# Connections: one FullConnection between each consecutive pair of layers.
layerChain = (inLayer,) + hiddenLayers + (outLayer,)
connections = [
    FullConnection(src, dst) for src, dst in zip(layerChain, layerChain[1:])
]
(in_to_hidden1, hidden1_to_hidden2, hidden2_to_hidden3,
 hidden3_to_hidden4, hidden4_to_out) = connections
for connection in connections:
    net.addConnection(connection)

# Finalize the topology; pybrain requires this before the net is activated.
net.sortModules()
# Data set: each sample maps the 4-bit encoding of an integer to the integer.
ds = SupervisedDataSet(4, 1)
# Every value 0..15 except 14.
# NOTE(review): 14 looks deliberately held out, since the activation sweep
# queries all 16 values. Confirm before "fixing" it.
numbers = [n for n in range(16) if n != 14]
for value in numbers:
    ds.addSample(tuple(int2bin(value)), (value,))
print(ds)
# Training
# The default is a fixed number of backprop epochs, printing the error after
# each one. Set TRAIN_UNTIL_CONVERGENCE to True to use pybrain's
# trainUntilConvergence() instead. NOTE: that method holds out part of ds for
# validation (pybrain default 25%), which matters with this few samples.
TRAIN_UNTIL_CONVERGENCE = False
EPOCHS = 100
trainer = BackpropTrainer(net, ds)
if TRAIN_UNTIL_CONVERGENCE:
    err = trainer.trainUntilConvergence()
else:
    for epoch in range(EPOCHS):
        # One epoch of backprop; its return value is reported as the training error.
        err = trainer.train()
        # %-formatting prints identically under Python 2 and Python 3.
        print("Traning error: %s" % (err,))
# Activation: query the trained net on every 4-bit input (0..15), including
# any value that was not in the training set.
for value in range(2 ** 4):
    print("%s %s" % (value, net.activate(int2bin(value))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment