Skip to content

Instantly share code, notes, and snippets.

@JoshCheek
Last active January 4, 2018 14:28
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save JoshCheek/a11fbb71e5d81ee79d4e to your computer and use it in GitHub Desktop.
Save JoshCheek/a11fbb71e5d81ee79d4e to your computer and use it in GitHub Desktop.
Overfit neural network to find prime numbers
# Classifies n as prime using a tiny, deliberately overfit neural network.
#
# Inputs are the low 8 bits of n (least significant first). They pass through
# one hidden layer of 9 sigmoid nodes and then a single sigmoid output node.
# Each weight row holds one synapse per node of the previous layer, plus a
# trailing bias weight; the bias pairs with the constant 1.0 appended to the
# inputs below.
#
# The weights were fit to the integers 0..255 only. Bits above bit 7 are never
# read, so larger (or negative) inputs would be silently misclassified. They
# are rejected instead.
#
# @param n [Integer] the number to classify, 0..255
# @return [Boolean] true when the network's output exceeds 0.5
# @raise [ArgumentError] if n is outside 0..255
def is_prime?(n)
  raise ArgumentError, "is_prime? only supports 0..255, got #{n.inspect}" unless (0..255).cover?(n)

  bits = 8.times.map { |i| n[i] }
  [[[-8.8,  3.0,  9.2,  5.0, -4.8,  5.6, -5.8,  6.1, -6.1],
    [-1.8,  2.3, -2.6, -5.6,  0.1,  6.0, -4.7, -5.7, -3.0],
    [-3.6, -4.1,  6.5, -0.6, -2.8, -2.6,  2.4, -2.3,  1.0],
    [ 2.4, -9.1, -3.1,  7.8, -3.7, -8.9, -2.8,  5.6,  6.3],
    [-2.0,  4.0, 11.0,  3.3, -6.0,  0.7, -7.0,  1.6, -3.0],
    [ 0.7, -7.2,  2.8,  4.5, -3.6, -1.5,  2.7, -0.1, -3.5],
    [-8.9,  5.5,  4.8, -4.1,  5.6,  4.8,  5.5, -4.0, -6.4],
    [-4.0, -6.7, -3.6,  5.5,  2.5, -6.9,  7.8, -4.1, -0.4],
    [ 1.1, -3.5, -6.7, -2.8,  2.3, -4.7,  6.3, -4.6, -2.2]
   ],
   [[-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2]],
  ]
    .inject(bits) { |nodes_lhs, synapseses_rhs_major|
      # Feed the previous layer's activations forward through every node of this layer.
      synapseses_rhs_major.map { |synapses|
        total_input = [*nodes_lhs, 1.0].zip(synapses).map { |n, s| n * s }.inject(0, :+)
        1 / (1 + Math::E ** -total_input) # https://en.wikipedia.org/wiki/Sigmoid_function
      }
    }.first > 0.5
end
require 'prime'
# Sanity check: the network must agree with Prime.prime? on every value in 0..255.
(0..255).all? { |candidate| is_prime?(candidate) == Prime.prime?(candidate) } # => true
require 'ai4r/neural_network/backpropagation'
require 'prime'
# srand 1
# Returns the network's input vector for n: its low 8 bits as 0/1 Integers,
# least significant bit first.
def bits_for(n)
  (0...8).map { |position| n[position] }
end
# Runs one training step that teaches net to output 1 for prime n and 0 otherwise.
# Returns whatever net.train returns; the training loop averages those values
# as an error measure.
def train(net, n)
  inputs = bits_for(n)
  target = n.prime? ? 1 : 0
  net.train(inputs, [target])
end
# Repeatedly trains a fresh 8-9-1 backpropagation network on every number in
# 0..255 until one classifies all of them correctly. When that happens, it
# opens a pry session so the network can be inspected, then keeps looping.
#
# Each attempt:
#   * trains 2001 passes over a shuffled 0..255, printing the average of the
#     values returned by `train` every 200 passes;
#   * from pass 600 on, at those checkpoints, runs 500 extra passes over just
#     the numbers the network currently misclassifies;
#   * finally prints every misclassified number and a summary line.
#
# A number counts as "predicted prime" when the network output is > 0.5. This
# is the same cut-off is_prime? uses, and it is applied in both the retraining
# and the test phase.
max = (2**8) - 1 # largest 8-bit value; the network sees 8 input bits

loop do
  puts "Training the network, please wait."
  net = Ai4r::NeuralNetwork::Backpropagation.new([8, 9, 1])

  2001.times do |i|
    errors = 0.upto(max).to_a.shuffle.map { |n| [n, train(net, n)] }
    next unless i % 200 == 0

    average = errors.map(&:last).inject(:+) / errors.length
    puts "Error for run #{i}: #{average.inspect}"
    next if i < 600

    # retrain = errors.select { |n, err| err > 0.2 }
    retrain = 0.upto(max).reject { |n| (net.eval(bits_for(n)).first > 0.5) == n.prime? }
    puts " retraining on #{retrain.inspect}"
    500.times { retrain.each { |n| train net, n } }
  end

  puts "Test data"
  num_correct = 0
  num_attempted = 0
  0.upto(max) do |n|
    num_attempted += 1
    result = net.eval(bits_for(n)).first
    # Same threshold as the retraining check above. The old `0.5 <= result`
    # disagreed with that check at exactly 0.5.
    actual = result > 0.5
    expected = n.prime?
    if expected == actual
      num_correct += 1
    else
      printf("%3d \e[31mreal:%-5s ai:%-5s\e[0m (%f)\n", n, expected, actual, result)
    end
  end
  puts
  puts "SUMMARY: #{num_correct}/#{num_attempted} (#{100.0 * num_correct / num_attempted}%)"

  # A perfect network was found: drop into a REPL to inspect it.
  if num_attempted == num_correct
    require "pry"
    binding.pry
  end
  puts "-----------------------------"
end
# require 'pp'
# puts '[' << synapseses_rhs_major.map { |s| s.map { |s| '%4.1f' % s }.join(', ') }.join("],\n")
# 8 bits
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]]
# [[[-3.2, 3.6, -0.7, 2.8, 0.9, 3.2, -0.1, 0.4, -2.3],
# [-0.2, -3.0, 0.9, -3.8, 1.3, 3.1, 1.0, -2.4, 0.2],
# [-1.0, 1.4, -0.0, -3.5, 1.6, 2.4, -2.9, 4.0, -1.0],
# [-0.6, 3.0, 1.9, 2.4, 0.4, 0.4, 2.7, 0.6, -0.9],
# [-4.1, -1.6, 2.8, -1.6, 0.6, -1.3, 3.3, 1.2, -0.5],
# [ 0.0, 3.8, -5.0, 4.1, -5.2, -1.1, -5.9, 4.6, 1.3],
# [-0.2, 2.0, -1.9, -1.0, -2.5, 6.2, -2.6, -0.3, -1.3],
# [-0.3, -0.7, 0.6, -0.8, 1.3, -0.6, 1.9, -2.2, -0.2],
# [-2.3, -0.3, 3.4, -2.6, 2.6, -1.2, 2.5, -2.1, -3.2],
# [ 0.8, -0.9, -4.4, 3.4, -1.6, 2.5, -0.6, -2.1, -0.9],
# [ 3.2, 3.7, 1.6, -3.8, -1.7, 4.5, -4.6, 3.2, -0.3],
# [-0.3, 0.7, 1.1, -0.3, -0.8, -0.5, -0.3, -0.8, -0.4],
# [-1.5, 1.7, -2.5, 0.3, -1.3, 2.0, -2.5, 0.4, 0.3],
# [ 2.9, -2.6, -1.7, 0.0, 1.4, -2.1, -1.7, 1.7, 0.9],
# [ 0.1, -2.5, 3.2, -1.7, 1.4, -2.5, 3.1, -1.6, -0.1],
# [-3.6, -2.1, -0.5, 4.1, -1.6, 0.6, -1.3, 4.4, -0.8],
# ],
# [[ 0.7, 1.1, -0.6, -0.6, 0.3, -0.7, 0.2, 1.0, -0.2, -0.7, -1.0, -0.9, 0.4, -0.9, -0.8, -0.5, -1.1],
# [-0.6, -0.9, -0.3, -0.9, 0.2, -0.1, 0.0, 0.3, 0.8, -0.7, 0.5, -0.4, -0.3, -0.8, -0.3, -0.5, -0.4],
# [-1.6, -0.5, -1.1, 0.2, 0.2, -1.5, 0.5, -0.7, 0.2, -0.9, 0.6, -0.3, 0.3, 0.0, 0.4, -0.1, -1.4],
# [ 2.0, 3.3, -0.1, -1.3, 3.2, 7.0, -6.3, 0.5, 4.6, -3.8, -5.1, 1.3, -1.0, 1.5, 0.3, -0.7, -0.8],
# [-1.5, 0.5, -0.5, -0.1, -0.9, -1.5, -0.6, 0.8, -0.6, -1.1, 1.3, -0.3, -0.2, 0.2, 0.8, -1.4, -0.5],
# [-2.1, 1.2, 1.2, -3.7, 0.3, 2.0, -1.9, 1.1, 1.6, 0.8, -0.8, 0.1, -0.0, -0.3, 1.0, -0.4, -0.2],
# [ 1.1, 2.4, 4.2, -1.1, 2.2, -7.7, 2.3, 1.1, 2.0, 3.5, -3.4, 0.5, 3.1, -3.7, 0.8, 5.4, 0.2],
# [-5.1, 1.4, -1.5, 1.4, -0.6, -5.4, -2.1, 2.1, 1.1, -1.5, 1.0, 0.1, -2.6, 0.3, 5.1, -2.7, -1.4],
# ],
# [[-0.7, 0.2, 0.9, -12.0, 1.4, -4.3, -11.1, 9.8, 5.4]],
# ]
# 7 bits
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6]]
# [[[-5.8, 3.6, 1.9, 1.1, 3.5, 2.1, 3.3, -4.9],
# [-3.0, 5.4, 6.8, 2.7, -5.6, 4.6, -6.4, -1.6],
# [-5.1, 0.2, 4.5, 2.3, -1.9, -2.3, 0.0, -1.0],
# [ 0.2, 2.2, -2.7, 4.8, 0.1, -3.0, -1.7, -0.7],
# [-2.6, 4.3, -4.8, 3.6, -3.2, 3.3, -6.0, 2.6],
# [-3.0, 2.5, 1.1, -4.1, 1.5, 6.5, -3.3, -4.1],
# [-1.7, -5.8, 2.2, 6.2, -0.5, -6.8, 3.7, 2.4]],
# [[-8.0, 9.6, -7.6, 8.6, -12.3, -9.4, -9.6, 5.4]],
# ]
# 8 bits
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]] # ~> NameError: undefined local variable or method `n' for main:Object
# [[[-8.8, 3.0, 9.2, 5.0, -4.8, 5.6, -5.8, 6.1, -6.1],
# [-1.8, 2.3, -2.6, -5.6, 0.1, 6.0, -4.7, -5.7, -3.0],
# [-3.6, -4.1, 6.5, -0.6, -2.8, -2.6, 2.4, -2.3, 1.0],
# [ 2.4, -9.1, -3.1, 7.8, -3.7, -8.9, -2.8, 5.6, 6.3],
# [-2.0, 4.0, 11.0, 3.3, -6.0, 0.7, -7.0, 1.6, -3.0],
# [ 0.7, -7.2, 2.8, 4.5, -3.6, -1.5, 2.7, -0.1, -3.5],
# [-8.9, 5.5, 4.8, -4.1, 5.6, 4.8, 5.5, -4.0, -6.4],
# [-4.0, -6.7, -3.6, 5.5, 2.5, -6.9, 7.8, -4.1, -0.4],
# [ 1.1, -3.5, -6.7, -2.8, 2.3, -4.7, 6.3, -4.6, -2.2]
# ],
# [[-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2]],
# ]
# Compact prime classifier: an 8-9-1 sigmoid network with mostly-integer
# weights. Its inputs are the low 8 bits of n (least significant first) plus a
# constant bias input of 1.
#
# Each node outputs 1 / (1 + e**(-sum(input * weight))). The inner inject
# subtracts each product, so its result is already the negated sum.
#
# Only bits 0..7 are read, so inputs outside 0..255 are rejected rather than
# silently misclassified.
#
# @param n [Integer] the number to classify, 0..255
# @return [Boolean] true when the network's output exceeds 0.5
# @raise [ArgumentError] if n is outside 0..255
def is_prime?(n)
  raise ArgumentError, "is_prime? only supports 0..255, got #{n.inspect}" unless (0..255).cover?(n)

  [[[-9,  3,  9,  5, -5,  6, -6,  6, -6],
    [-2,  2, -3, -6,  0,  6, -5, -6, -3],
    [-4, -4,  7,  0, -3, -3,  2, -2,  1],
    [ 2, -9, -3,  8, -4, -9, -3,  6,  6],
    [-2,  4, 11,  3, -6,  1, -7,  2, -3],
    [ 1, -7,  3,  4, -4, -2,  3,  0, -3.5],
    [-9,  6,  5, -4,  6,  5,  5, -4, -6],
    [-4, -7, -4,  5,  2, -7,  8, -4,  0],
    [ 1, -3, -7, -3,  2, -5,  6, -5, -2]],
   [[-15, -14, -9, -11, 14, 11, -11, -11, 11, 5]]]
    .inject(8.times.map { |i| n[i] }) { |l, r| r.map { |s| 1 / (1 + Math::E**[*l, 1].zip(s).inject(0) { |x, (b, w)| x - b * w }) } }[0] > 0.5
end
require 'prime'
# Sanity check: the compact network must agree with Integer#prime? on every value in 0..255.
(0..255).all? { |candidate| is_prime?(candidate) == candidate.prime? } # => true
@tarvos21
Copy link

tarvos21 commented Jan 4, 2018

How large a number can the algorithm take?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment