Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Overfit neural network to find prime numbers
# Decides whether +n+ is prime by running its eight low-order bits through a
# hard-coded feed-forward network: 8 inputs -> 9 sigmoid units -> 1 sigmoid unit.
#
# Each weight row holds one weight per incoming value followed by one extra
# weight that is multiplied by a constant 1.0 input.
#
# @param n [Integer] number to classify; only bits 0..7 are read
# @return [Boolean] true when the output unit's activation exceeds 0.5
def is_prime?(n)
  hidden_layer = [
    [-8.8,  3.0,  9.2,  5.0, -4.8,  5.6, -5.8,  6.1, -6.1],
    [-1.8,  2.3, -2.6, -5.6,  0.1,  6.0, -4.7, -5.7, -3.0],
    [-3.6, -4.1,  6.5, -0.6, -2.8, -2.6,  2.4, -2.3,  1.0],
    [ 2.4, -9.1, -3.1,  7.8, -3.7, -8.9, -2.8,  5.6,  6.3],
    [-2.0,  4.0, 11.0,  3.3, -6.0,  0.7, -7.0,  1.6, -3.0],
    [ 0.7, -7.2,  2.8,  4.5, -3.6, -1.5,  2.7, -0.1, -3.5],
    [-8.9,  5.5,  4.8, -4.1,  5.6,  4.8,  5.5, -4.0, -6.4],
    [-4.0, -6.7, -3.6,  5.5,  2.5, -6.9,  7.8, -4.1, -0.4],
    [ 1.1, -3.5, -6.7, -2.8,  2.3, -4.7,  6.3, -4.6, -2.2]
  ]
  output_layer = [
    [-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2]
  ]

  # https://en.wikipedia.org/wiki/Sigmoid_function
  sigmoid = ->(x) { 1 / (1 + Math::E**(-x)) }

  # Input activations are the bits of n, least-significant first.
  activations = Array.new(8) { |bit| n[bit] }

  [hidden_layer, output_layer].each do |layer|
    activations = layer.map do |weights|
      inputs = [*activations, 1.0] # trailing 1.0 pairs with the row's last weight
      weighted_sum = inputs.each_index.reduce(0) do |acc, idx|
        acc + inputs[idx] * weights[idx]
      end
      sigmoid.call(weighted_sum)
    end
  end

  activations.first > 0.5
end
require 'prime'
# Compare the network against the stdlib primality test for every 8-bit value.
(0...256).all? { |candidate| is_prime?(candidate) == Prime.prime?(candidate) } # => true
require 'ai4r/neural_network/backpropagation'
require 'prime'
# srand 1
# Splits an integer into its low-order bits, least-significant bit first.
#
# @param n [Integer] number whose bits are read via Integer#[]
# @param width [Integer] how many bits to return (defaults to 8, the size
#   of the network's input layer used in this script)
# @return [Array<Integer>] +width+ values, each 0 or 1
def bits_for(n, width: 8)
  Array.new(width) { |i| n[i] }
end
# Feeds one training example to the network: the bits of +n+ as input and a
# single-element target of 1 (prime) or 0 (not prime).
#
# @param net [#train] object receiving the (inputs, outputs) pair
# @param n [Integer] number to train on
# @return whatever +net.train+ returns
def train(net, n)
  target = n.prime? ? 1 : 0
  net.train(bits_for(n), [target])
end
# Training driver. Each pass of the outer loop builds a fresh 8-9-1
# backpropagation network, trains it on every integer in 0..255, then scores
# it against Integer#prime?. The loop never exits on its own; when a network
# gets every number right it opens a pry session (presumably so the trained
# weights can be inspected and copied out).
loop do
  puts "Training the network, please wait."
  max = (2**8) - 1
  net = Ai4r::NeuralNetwork::Backpropagation.new([8, 9, 1])

  2001.times do |i|
    # One epoch: every number once, in random order, paired with train's result.
    errors = 0.upto(max).to_a.shuffle.map { |n| [n, train(net, n)] }

    # Progress report and targeted retraining only happen every 200th epoch.
    next unless (i % 200).zero?

    average = errors.map(&:last).inject(:+) / errors.length
    puts "Error for run #{i}: #{average.inspect}"
    next if i < 600

    # From epoch 600 onward, hammer on the numbers the network still gets wrong.
    # retrain = errors.select { |n, err| err > 0.2 }
    retrain = 0.upto(max).reject { |n| (0.5 < net.eval(bits_for n).first) == n.prime? }
    puts " retraining on #{retrain.inspect}"
    500.times { retrain.each { |n| train net, n } }
  end

  puts "Test data"
  num_correct = 0
  num_attempted = 0
  0.upto max do |n|
    num_attempted += 1
    result = net.eval(bits_for n).first
    # Strict comparison, matching the retraining filter above and the
    # `> 0.5` threshold used by the hard-coded is_prime? networks.
    actual = (0.5 < result)
    expected = n.prime?
    if expected == actual
      num_correct += 1
    else
      printf("%3d \e[31mreal:%-5s ai:%-5s\e[0m (%f)\n", n, expected, actual, result)
    end
  end
  puts
  puts "SUMMARY: #{num_correct}/#{num_attempted} (#{100.0*num_correct/num_attempted}%)"

  if num_attempted == num_correct
    # Loaded lazily: pry is only needed once a perfect network has been found.
    require "pry"
    binding.pry
  end
  puts "-----------------------------"
end
# require 'pp'
# puts '[' << synapseses_rhs_major.map { |s| s.map { |s| '%4.1f' % s }.join(', ') }.join("],\n")
# 8 bits
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]]
# [[[-3.2, 3.6, -0.7, 2.8, 0.9, 3.2, -0.1, 0.4, -2.3],
# [-0.2, -3.0, 0.9, -3.8, 1.3, 3.1, 1.0, -2.4, 0.2],
# [-1.0, 1.4, -0.0, -3.5, 1.6, 2.4, -2.9, 4.0, -1.0],
# [-0.6, 3.0, 1.9, 2.4, 0.4, 0.4, 2.7, 0.6, -0.9],
# [-4.1, -1.6, 2.8, -1.6, 0.6, -1.3, 3.3, 1.2, -0.5],
# [ 0.0, 3.8, -5.0, 4.1, -5.2, -1.1, -5.9, 4.6, 1.3],
# [-0.2, 2.0, -1.9, -1.0, -2.5, 6.2, -2.6, -0.3, -1.3],
# [-0.3, -0.7, 0.6, -0.8, 1.3, -0.6, 1.9, -2.2, -0.2],
# [-2.3, -0.3, 3.4, -2.6, 2.6, -1.2, 2.5, -2.1, -3.2],
# [ 0.8, -0.9, -4.4, 3.4, -1.6, 2.5, -0.6, -2.1, -0.9],
# [ 3.2, 3.7, 1.6, -3.8, -1.7, 4.5, -4.6, 3.2, -0.3],
# [-0.3, 0.7, 1.1, -0.3, -0.8, -0.5, -0.3, -0.8, -0.4],
# [-1.5, 1.7, -2.5, 0.3, -1.3, 2.0, -2.5, 0.4, 0.3],
# [ 2.9, -2.6, -1.7, 0.0, 1.4, -2.1, -1.7, 1.7, 0.9],
# [ 0.1, -2.5, 3.2, -1.7, 1.4, -2.5, 3.1, -1.6, -0.1],
# [-3.6, -2.1, -0.5, 4.1, -1.6, 0.6, -1.3, 4.4, -0.8],
# ],
# [[ 0.7, 1.1, -0.6, -0.6, 0.3, -0.7, 0.2, 1.0, -0.2, -0.7, -1.0, -0.9, 0.4, -0.9, -0.8, -0.5, -1.1],
# [-0.6, -0.9, -0.3, -0.9, 0.2, -0.1, 0.0, 0.3, 0.8, -0.7, 0.5, -0.4, -0.3, -0.8, -0.3, -0.5, -0.4],
# [-1.6, -0.5, -1.1, 0.2, 0.2, -1.5, 0.5, -0.7, 0.2, -0.9, 0.6, -0.3, 0.3, 0.0, 0.4, -0.1, -1.4],
# [ 2.0, 3.3, -0.1, -1.3, 3.2, 7.0, -6.3, 0.5, 4.6, -3.8, -5.1, 1.3, -1.0, 1.5, 0.3, -0.7, -0.8],
# [-1.5, 0.5, -0.5, -0.1, -0.9, -1.5, -0.6, 0.8, -0.6, -1.1, 1.3, -0.3, -0.2, 0.2, 0.8, -1.4, -0.5],
# [-2.1, 1.2, 1.2, -3.7, 0.3, 2.0, -1.9, 1.1, 1.6, 0.8, -0.8, 0.1, -0.0, -0.3, 1.0, -0.4, -0.2],
# [ 1.1, 2.4, 4.2, -1.1, 2.2, -7.7, 2.3, 1.1, 2.0, 3.5, -3.4, 0.5, 3.1, -3.7, 0.8, 5.4, 0.2],
# [-5.1, 1.4, -1.5, 1.4, -0.6, -5.4, -2.1, 2.1, 1.1, -1.5, 1.0, 0.1, -2.6, 0.3, 5.1, -2.7, -1.4],
# ],
# [[-0.7, 0.2, 0.9, -12.0, 1.4, -4.3, -11.1, 9.8, 5.4]],
# ]
# 7 bits
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6]]
# [[[-5.8, 3.6, 1.9, 1.1, 3.5, 2.1, 3.3, -4.9],
# [-3.0, 5.4, 6.8, 2.7, -5.6, 4.6, -6.4, -1.6],
# [-5.1, 0.2, 4.5, 2.3, -1.9, -2.3, 0.0, -1.0],
# [ 0.2, 2.2, -2.7, 4.8, 0.1, -3.0, -1.7, -0.7],
# [-2.6, 4.3, -4.8, 3.6, -3.2, 3.3, -6.0, 2.6],
# [-3.0, 2.5, 1.1, -4.1, 1.5, 6.5, -3.3, -4.1],
# [-1.7, -5.8, 2.2, 6.2, -0.5, -6.8, 3.7, 2.4]],
# [[-8.0, 9.6, -7.6, 8.6, -12.3, -9.4, -9.6, 5.4]],
# ]
# 8 bits
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]] # ~> NameError: undefined local variable or method `n' for main:Object
# [[[-8.8, 3.0, 9.2, 5.0, -4.8, 5.6, -5.8, 6.1, -6.1],
# [-1.8, 2.3, -2.6, -5.6, 0.1, 6.0, -4.7, -5.7, -3.0],
# [-3.6, -4.1, 6.5, -0.6, -2.8, -2.6, 2.4, -2.3, 1.0],
# [ 2.4, -9.1, -3.1, 7.8, -3.7, -8.9, -2.8, 5.6, 6.3],
# [-2.0, 4.0, 11.0, 3.3, -6.0, 0.7, -7.0, 1.6, -3.0],
# [ 0.7, -7.2, 2.8, 4.5, -3.6, -1.5, 2.7, -0.1, -3.5],
# [-8.9, 5.5, 4.8, -4.1, 5.6, 4.8, 5.5, -4.0, -6.4],
# [-4.0, -6.7, -3.6, 5.5, 2.5, -6.9, 7.8, -4.1, -0.4],
# [ 1.1, -3.5, -6.7, -2.8, 2.3, -4.7, 6.3, -4.6, -2.2]
# ],
# [[-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2]],
# ]
# Compact variant of the 8-9-1 prime classifier using (mostly) integer weights.
#
# The eight low-order bits of +n+ feed nine sigmoid units, which feed a single
# sigmoid output unit. The last weight in every row is multiplied by a
# constant input of 1.
#
# @param n [Integer] number to classify; only bits 0..7 are read
# @return [Boolean] true when the output activation exceeds 0.5
def is_prime?(n)
  hidden_weights = [
    [-9,  3,  9,  5, -5,  6, -6,  6, -6],
    [-2,  2, -3, -6,  0,  6, -5, -6, -3],
    [-4, -4,  7,  0, -3, -3,  2, -2,  1],
    [ 2, -9, -3,  8, -4, -9, -3,  6,  6],
    [-2,  4, 11,  3, -6,  1, -7,  2, -3],
    [ 1, -7,  3,  4, -4, -2,  3,  0, -3.5],
    [-9,  6,  5, -4,  6,  5,  5, -4, -6],
    [-4, -7, -4,  5,  2, -7,  8, -4,  0],
    [ 1, -3, -7, -3,  2, -5,  6, -5, -2]
  ]
  output_weights = [
    [-15, -14, -9, -11, 14, 11, -11, -11, 11, 5]
  ]

  # Start from the bits of n, least-significant first.
  signals = Array.new(8) { |bit| n[bit] }

  [hidden_weights, output_weights].each do |layer|
    signals = layer.map do |row|
      # Accumulate the *negated* weighted sum so it can be used directly as
      # the exponent in the sigmoid 1 / (1 + e^-x).
      negated_sum = 0
      [*signals, 1].each_with_index do |value, idx|
        negated_sum -= value * row[idx]
      end
      1 / (1 + Math::E**negated_sum)
    end
  end

  signals[0] > 0.5
end
require 'prime'
# Compare the compact network against Integer#prime? for every 8-bit value.
(0...256).all? { |candidate| is_prime?(candidate) == candidate.prime? } # => true
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment