# _prime_nn.rb: an overfit neural network that finds prime numbers.
# Gist by tarvos21, forked from JoshCheek/_prime_nn.rb (created Jan 4, 2018).
 # Weights of the trained 8-9-1 sigmoid network used by is_prime? below.
# Layer 0: nine hidden neurons, each row = 8 input weights followed by a bias.
# Layer 1: one output neuron, row = 9 hidden-neuron weights followed by a bias.
# Defined once (and frozen) so each call does not rebuild the nested arrays.
PRIME_NN_LAYERS = [
  [
    [-8.8,  3.0,  9.2,  5.0, -4.8,  5.6, -5.8,  6.1, -6.1],
    [-1.8,  2.3, -2.6, -5.6,  0.1,  6.0, -4.7, -5.7, -3.0],
    [-3.6, -4.1,  6.5, -0.6, -2.8, -2.6,  2.4, -2.3,  1.0],
    [ 2.4, -9.1, -3.1,  7.8, -3.7, -8.9, -2.8,  5.6,  6.3],
    [-2.0,  4.0, 11.0,  3.3, -6.0,  0.7, -7.0,  1.6, -3.0],
    [ 0.7, -7.2,  2.8,  4.5, -3.6, -1.5,  2.7, -0.1, -3.5],
    [-8.9,  5.5,  4.8, -4.1,  5.6,  4.8,  5.5, -4.0, -6.4],
    [-4.0, -6.7, -3.6,  5.5,  2.5, -6.9,  7.8, -4.1, -0.4],
    [ 1.1, -3.5, -6.7, -2.8,  2.3, -4.7,  6.3, -4.6, -2.2],
  ],
  [
    [-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2],
  ],
].each { |layer| layer.each(&:freeze).freeze }.freeze

# Predicts whether +n+ is prime using a tiny feed-forward sigmoid network
# that was deliberately overfit to the integers 0..255.
#
# Only the low 8 bits of +n+ are fed to the network (Integer#[] reads bit i),
# so answers are meaningful only for 0 <= n <= 255. Any other value is
# silently reduced to its low 8 bits (two's complement for negatives).
#
# @param n [Integer]
# @return [Boolean] true when the output neuron's activation exceeds 0.5
def is_prime?(n)
  bits = 8.times.map { |i| n[i] }
  output = PRIME_NN_LAYERS.inject(bits) do |nodes_lhs, neurons|
    neurons.map do |synapses|
      # The trailing 1.0 is the constant input that each row's bias weight multiplies.
      total_input = [*nodes_lhs, 1.0].zip(synapses).sum { |node, weight| node * weight }
      1 / (1 + Math.exp(-total_input)) # https://en.wikipedia.org/wiki/Sigmoid_function
    end
  end
  output.first > 0.5
end

require 'prime'
256.times.all? { |n| is_prime?(n) == Prime.prime?(n) } # => true
 # Trains a fresh 8-9-1 ai4r backpropagation network to classify 0..255 as
# prime / not prime. Each attempt runs 2001 shuffled epochs and, from epoch 600
# on, drills the inputs it still gets wrong. After each attempt it prints a
# summary; a perfect network opens a pry session (so its weights can be dumped)
# and the loop then starts another attempt.
require 'ai4r/neural_network/backpropagation'
require 'prime'
# srand 1

# Low 8 bits of +n+, least-significant first: the network's input layer.
def bits_for(n)
  8.times.map { |i| n[i] }
end

# Runs one backpropagation step teaching +net+ whether +n+ is prime (target 1
# or 0) and returns the value Backpropagation#train returns, which the loop
# below averages and reports as the error.
def train(net, n)
  net.train bits_for(n), [n.prime? ? 1 : 0]
end

# Raw activation of +net+'s single output neuron for input +n+.
# Both the retraining step and the test pass read "prime" as output > 0.5.
def network_output(net, n)
  net.eval(bits_for(n)).first
end

max = (2**8) - 1 # largest value the 8 input bits can encode

loop do
  puts "Training the network, please wait."
  net = Ai4r::NeuralNetwork::Backpropagation.new([8, 9, 1])
  2001.times do |i|
    # One epoch: every input once, in random order.
    errors = 0.upto(max).to_a.shuffle.map { |n| train(net, n) }
    next unless i % 200 == 0

    average = errors.sum / errors.length
    puts "Error for run #{i}: #{average.inspect}"
    next if i < 600

    # Drill the inputs the network currently misclassifies.
    retrain = 0.upto(max).reject { |n| (network_output(net, n) > 0.5) == n.prime? }
    puts " retraining on #{retrain.inspect}"
    500.times { retrain.each { |n| train(net, n) } }
  end

  puts "Test data"
  num_correct = 0
  num_attempted = 0
  0.upto(max) do |n|
    num_attempted += 1
    result = network_output(net, n)
    actual = result > 0.5 # same threshold as the retraining step above
    expected = n.prime?
    if expected == actual
      num_correct += 1
    else
      printf("%3d \e[31mreal:%-5s ai:%-5s\e[0m (%f)\n", n, expected, actual, result)
    end
  end
  puts
  puts "SUMMARY: #{num_correct}/#{num_attempted} (#{100.0*num_correct/num_attempted}%)"
  if num_attempted == num_correct
    require "pry"
    binding.pry
  end
  puts "-----------------------------"
end
# Scratch notes (kept commented out): the snippet used to dump a trained
# network's weights, followed by weight sets from several training runs.
#
# require 'pp'
# puts '[' << synapseses_rhs_major.map { |s| s.map { |s| '%4.1f' % s }.join(', ') }.join("],\n")
#
# 8 bits (8-16-8-1 network)
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]]
# [[[-3.2, 3.6, -0.7, 2.8, 0.9, 3.2, -0.1, 0.4, -2.3],
#   [-0.2, -3.0, 0.9, -3.8, 1.3, 3.1, 1.0, -2.4, 0.2],
#   [-1.0, 1.4, -0.0, -3.5, 1.6, 2.4, -2.9, 4.0, -1.0],
#   [-0.6, 3.0, 1.9, 2.4, 0.4, 0.4, 2.7, 0.6, -0.9],
#   [-4.1, -1.6, 2.8, -1.6, 0.6, -1.3, 3.3, 1.2, -0.5],
#   [ 0.0, 3.8, -5.0, 4.1, -5.2, -1.1, -5.9, 4.6, 1.3],
#   [-0.2, 2.0, -1.9, -1.0, -2.5, 6.2, -2.6, -0.3, -1.3],
#   [-0.3, -0.7, 0.6, -0.8, 1.3, -0.6, 1.9, -2.2, -0.2],
#   [-2.3, -0.3, 3.4, -2.6, 2.6, -1.2, 2.5, -2.1, -3.2],
#   [ 0.8, -0.9, -4.4, 3.4, -1.6, 2.5, -0.6, -2.1, -0.9],
#   [ 3.2, 3.7, 1.6, -3.8, -1.7, 4.5, -4.6, 3.2, -0.3],
#   [-0.3, 0.7, 1.1, -0.3, -0.8, -0.5, -0.3, -0.8, -0.4],
#   [-1.5, 1.7, -2.5, 0.3, -1.3, 2.0, -2.5, 0.4, 0.3],
#   [ 2.9, -2.6, -1.7, 0.0, 1.4, -2.1, -1.7, 1.7, 0.9],
#   [ 0.1, -2.5, 3.2, -1.7, 1.4, -2.5, 3.1, -1.6, -0.1],
#   [-3.6, -2.1, -0.5, 4.1, -1.6, 0.6, -1.3, 4.4, -0.8],
#  ],
#  [[ 0.7, 1.1, -0.6, -0.6, 0.3, -0.7, 0.2, 1.0, -0.2, -0.7, -1.0, -0.9, 0.4, -0.9, -0.8, -0.5, -1.1],
#   [-0.6, -0.9, -0.3, -0.9, 0.2, -0.1, 0.0, 0.3, 0.8, -0.7, 0.5, -0.4, -0.3, -0.8, -0.3, -0.5, -0.4],
#   [-1.6, -0.5, -1.1, 0.2, 0.2, -1.5, 0.5, -0.7, 0.2, -0.9, 0.6, -0.3, 0.3, 0.0, 0.4, -0.1, -1.4],
#   [ 2.0, 3.3, -0.1, -1.3, 3.2, 7.0, -6.3, 0.5, 4.6, -3.8, -5.1, 1.3, -1.0, 1.5, 0.3, -0.7, -0.8],
#   [-1.5, 0.5, -0.5, -0.1, -0.9, -1.5, -0.6, 0.8, -0.6, -1.1, 1.3, -0.3, -0.2, 0.2, 0.8, -1.4, -0.5],
#   [-2.1, 1.2, 1.2, -3.7, 0.3, 2.0, -1.9, 1.1, 1.6, 0.8, -0.8, 0.1, -0.0, -0.3, 1.0, -0.4, -0.2],
#   [ 1.1, 2.4, 4.2, -1.1, 2.2, -7.7, 2.3, 1.1, 2.0, 3.5, -3.4, 0.5, 3.1, -3.7, 0.8, 5.4, 0.2],
#   [-5.1, 1.4, -1.5, 1.4, -0.6, -5.4, -2.1, 2.1, 1.1, -1.5, 1.0, 0.1, -2.6, 0.3, 5.1, -2.7, -1.4],
#  ],
#  [[-0.7, 0.2, 0.9, -12.0, 1.4, -4.3, -11.1, 9.8, 5.4]],
# ]
#
# 7 bits (7-7-1 network)
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6]]
# [[[-5.8, 3.6, 1.9, 1.1, 3.5, 2.1, 3.3, -4.9],
#   [-3.0, 5.4, 6.8, 2.7, -5.6, 4.6, -6.4, -1.6],
#   [-5.1, 0.2, 4.5, 2.3, -1.9, -2.3, 0.0, -1.0],
#   [ 0.2, 2.2, -2.7, 4.8, 0.1, -3.0, -1.7, -0.7],
#   [-2.6, 4.3, -4.8, 3.6, -3.2, 3.3, -6.0, 2.6],
#   [-3.0, 2.5, 1.1, -4.1, 1.5, 6.5, -3.3, -4.1],
#   [-1.7, -5.8, 2.2, 6.2, -0.5, -6.8, 3.7, 2.4]],
#  [[-8.0, 9.6, -7.6, 8.6, -12.3, -9.4, -9.6, 5.4]],
# ]
#
# 8 bits (8-9-1 network)
# bits = [n[0], n[1], n[2], n[3], n[4], n[5], n[6], n[7]] # ~> NameError: undefined local variable or method `n' for main:Object
# [[[-8.8, 3.0, 9.2, 5.0, -4.8, 5.6, -5.8, 6.1, -6.1],
#   [-1.8, 2.3, -2.6, -5.6, 0.1, 6.0, -4.7, -5.7, -3.0],
#   [-3.6, -4.1, 6.5, -0.6, -2.8, -2.6, 2.4, -2.3, 1.0],
#   [ 2.4, -9.1, -3.1, 7.8, -3.7, -8.9, -2.8, 5.6, 6.3],
#   [-2.0, 4.0, 11.0, 3.3, -6.0, 0.7, -7.0, 1.6, -3.0],
#   [ 0.7, -7.2, 2.8, 4.5, -3.6, -1.5, 2.7, -0.1, -3.5],
#   [-8.9, 5.5, 4.8, -4.1, 5.6, 4.8, 5.5, -4.0, -6.4],
#   [-4.0, -6.7, -3.6, 5.5, 2.5, -6.9, 7.8, -4.1, -0.4],
#   [ 1.1, -3.5, -6.7, -2.8, 2.3, -4.7, 6.3, -4.6, -2.2]
#  ],
#  [[-14.6, -13.9, -9.4, -10.8, 13.7, 10.6, -10.7, -11.0, 12.7, 5.2]],
# ]
 # Compact variant of is_prime?: an 8-9-1 sigmoid network with a
# hand-simplified weight set (mostly whole numbers). Defining it replaces
# any earlier is_prime? definition.
#
# Only the low 8 bits of +n+ are used (Integer#[] reads bit i), so answers
# are meaningful only for 0 <= n <= 255.
#
# @param n [Integer]
# @return [Boolean] true when the output neuron's activation exceeds 0.5
def is_prime?(n)
  # Each row holds one neuron's weights for the previous layer's outputs,
  # followed by the weight for the constant bias input 1.
  layers = [
    [[-9,  3,  9,  5, -5,  6, -6,  6, -6],
     [-2,  2, -3, -6,  0,  6, -5, -6, -3],
     [-4, -4,  7,  0, -3, -3,  2, -2,  1],
     [ 2, -9, -3,  8, -4, -9, -3,  6,  6],
     [-2,  4, 11,  3, -6,  1, -7,  2, -3],
     [ 1, -7,  3,  4, -4, -2,  3,  0, -3.5],
     [-9,  6,  5, -4,  6,  5,  5, -4, -6],
     [-4, -7, -4,  5,  2, -7,  8, -4,  0],
     [ 1, -3, -7, -3,  2, -5,  6, -5, -2]],
    [[-15, -14, -9, -11, 14, 11, -11, -11, 11, 5]],
  ]

  activations = 8.times.map { |bit| n[bit] }
  layers.each do |neurons|
    activations = neurons.map do |weights|
      # Accumulating by subtraction yields the *negated* weighted input,
      # which is exactly the exponent the sigmoid 1 / (1 + e**-x) needs.
      negated_input = [*activations, 1].zip(weights).inject(0) do |acc, (value, weight)|
        acc - value * weight
      end
      1 / (1 + Math::E**negated_input)
    end
  end
  activations.first > 0.5
end

require 'prime'
256.times.all? { |n| is_prime?(n) == n.prime? } # => true