Challenging to find prime numbers with Keras...
import numpy as np
from keras.layers import Dense, Dropout, Activation
from keras.layers.advanced_activations import PReLU
from keras.models import Sequential
from matplotlib import pyplot as plt

seed = 7
np.random.seed(seed)

num_digits = 14  # numbers are binary-encoded with this many bits
max_number = 2 ** num_digits  # 16384


def prime_list():
    # Trial division by previously found primes, odd candidates only.
    primes = [2, 3]
    for n in range(5, max_number, 2):
        is_prime = True
        for i in range(1, len(primes)):  # start at 3; odd n is never divisible by 2
            if primes[i] ** 2 > n:
                break
            if n % primes[i] == 0:
                is_prime = False
                break
        if is_prime:
            primes.append(n)
    return primes


primes = set(prime_list())  # a set makes the membership test below O(1)
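# Sanity check (a quick sketch): the list should start with the first primes.
assert sorted(primes)[:5] == [2, 3, 5, 7, 11]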
def prime_encode(i):
    # 1 for prime, 0 for composite.
    return 1 if i in primes else 0
def bin_encode(i):
    # Least-significant bit first, zero-padded to num_digits bits.
    return [i >> d & 1 for d in range(num_digits)]
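# Quick example of the encoding: 13 == 0b1101, so the low four bits are
# [1, 0, 1, 1] and the remaining bits pad out with zeros.
assert bin_encode(13)[:4] == [1, 0, 1, 1]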
def create_dataset():
    # Train on 102..max_number; numbers below 102 are held out and
    # evaluated after training.
    x, y = [], []
    for i in range(102, max_number):
        x.append(bin_encode(i))
        y.append(prime_encode(i))
    return np.array(x), np.array(y)


x_train, y_train = create_dataset()
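# One row per integer in [102, max_number), num_digits features per row.
assert x_train.shape == (max_number - 102, num_digits)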
model = Sequential()
model.add(Dense(units=100, input_dim=num_digits))
model.add(PReLU())
model.add(Dropout(rate=0.2))
model.add(Dense(units=50))
model.add(PReLU())
model.add(Dropout(rate=0.2))
model.add(Dense(units=25))
model.add(PReLU())
model.add(Dropout(rate=0.2))
model.add(Dense(units=1))
model.add(Activation("sigmoid"))

model.compile(optimizer='RMSprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=1000, batch_size=128,
                    validation_split=0.1)
# Evaluate on the held-out numbers 2..100.
errors, correct = 0, 0
tp, fn, fp = 0, 0, 0
for i in range(2, 101):
    x = bin_encode(i)
    y = model.predict(np.array(x).reshape(-1, num_digits))
    pred = 1 if y[0][0] >= 0.5 else 0
    obs = prime_encode(i)
    print(i, obs, pred, y[0][0])
    if pred == obs:
        correct += 1
    else:
        errors += 1
    if obs == 1 and pred == 1:
        tp += 1
    if obs == 1 and pred == 0:
        fn += 1
    if obs == 0 and pred == 1:
        fp += 1

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f_score = 2 * precision * recall / (precision + recall)
print("Errors :", errors, " Correct :", correct, "F score :", f_score)
def plot_history(history):
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['loss', 'val_loss'], loc='upper right')
    plt.savefig('RMSprop_more')


plot_history(history)
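# Usage sketch: query the trained model for a number outside the 2..100
# evaluation range, e.g. 113 (a prime), using the same 0.5 threshold as
# the evaluation loop above.
p = model.predict(np.array(bin_encode(113)).reshape(-1, num_digits))[0][0]
print(113, 'prime' if p >= 0.5 else 'not prime', p)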
akirasosa commented Apr 14, 2017

Here is the result.

Errors : 9  Correct : 90 F score : 0.8235294117647058

[Figure: model loss over epochs, loss vs. val_loss (rmsprop_more)]

n	obs	pred	p(prime)
2	1	0	3.60083e-08
3	1	0	0.0372073
4	0	0	8.21077e-06
5	1	1	0.617928
6	0	0	3.97566e-15
7	1	1	0.826378
8	0	0	0.000872908
9	0	0	0.0204348
10	0	0	1.00783e-09
11	1	0	0.187609
12	0	0	5.53888e-08
13	1	1	0.796401
14	0	0	3.67647e-15
15	0	0	0.029023
16	0	0	4.0449e-07
17	1	1	0.843016
18	0	0	4.10257e-14
19	1	1	0.917604
20	0	0	4.3736e-15
21	0	0	0.0141451
22	0	0	6.49346e-25
23	1	1	0.593597
24	0	0	4.45055e-08
25	0	1	0.735933
26	0	0	4.97732e-18
27	0	0	0.0958722
28	0	0	1.80154e-16
29	1	1	0.722513
30	0	0	2.00777e-28
31	1	1	0.774054
32	0	0	3.93779e-05
33	0	0	0.118341
34	0	0	1.88295e-11
35	0	0	0.480108
36	0	0	3.0609e-07
37	1	1	0.847888
38	0	0	3.42833e-18
39	0	0	0.0514646
40	0	0	5.82673e-07
41	1	1	0.726771
42	0	0	3.72693e-11
43	1	1	0.861872
44	0	0	5.71867e-14
45	0	0	0.18657
46	0	0	7.03075e-16
47	1	1	0.654062
48	0	0	1.30385e-10
49	0	1	0.923631
50	0	0	1.30955e-17
51	0	0	0.190215
52	0	0	6.45953e-19
53	1	1	0.558284
54	0	0	1.83163e-29
55	0	0	0.287756
56	0	0	3.29105e-11
57	0	0	0.292637
58	0	0	3.57044e-23
59	1	0	0.152102
60	0	0	1.80104e-22
61	1	1	0.858877
62	0	0	1.92684e-32
63	0	0	0.27367
64	0	0	1.74397e-09
65	0	1	0.727574
66	0	0	1.33752e-20
67	1	1	0.891129
68	0	0	1.47396e-17
69	0	0	0.346057
70	0	0	5.27672e-27
71	1	1	0.932053
72	0	0	4.04155e-10
73	1	1	0.879374
74	0	0	1.4077e-18
75	0	0	0.0290487
76	0	0	6.39801e-17
77	0	1	0.629597
78	0	0	1.54139e-30
79	1	1	0.791511
80	0	0	7.56631e-21
81	0	0	0.0438443
82	0	0	4.24787e-30
83	1	1	0.596353
84	0	0	6.45592e-32
85	0	0	0.431211
86	0	0	0.0
87	0	0	0.00903795
88	0	0	9.54647e-23
89	1	1	0.827787
90	0	0	2.43897e-31
91	0	1	0.746695
92	0	0	8.37092e-37
93	0	0	0.0384408
94	0	0	0.0
95	0	0	0.3743
96	0	0	7.28071e-13
97	1	1	0.888417
98	0	0	3.04541e-25
99	0	0	0.0649973
100	0	0	1.59478e-18
