@kiki67100
Forked from rosdyana/simple_keras.py
Created December 20, 2017 06:26
simple keras
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
# dataset
trainX = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
trainY = np.array([3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72])
# create a model
model = Sequential()
model.add(Dense(8, input_dim=1, activation='relu'))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
# Keras 2 renamed the nb_epoch argument to epochs
model.fit(trainX, trainY, epochs=1500, batch_size=2, verbose=2)
# try to predict
# each label is passed inside the print() call so it also prints under Python 3
dataPrediction = model.predict(np.array([64]))
print(int(dataPrediction[0][0]), '<--- Predicted number')
print(64*3, ' <-- Correct answer \n')
dataPrediction = model.predict(np.array([56]))
print(int(dataPrediction[0][0]), '<--- Predicted number')
print(56*3, ' <-- Correct answer \n')
dataPrediction = model.predict(np.array([345]))
print(int(dataPrediction[0][0]), '<--- Predicted number')
print(345*3, ' <-- Correct answer \n')
dataPrediction = model.predict(np.array([456]))
print(int(dataPrediction[0][0]), '<--- Predicted number')
print(456*3, ' <-- Correct answer \n')
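
The four predict-and-compare steps above repeat the same pattern, so they could be folded into one helper. The following is only a sketch: `compare_predictions` and its `factor` parameter are hypothetical names that are not part of the original gist, and the "model" used here is a plain stand-in function, so the snippet runs without Keras.

```python
import numpy as np

def compare_predictions(predict_fn, inputs, factor=3):
    """Return (input, predicted, expected) rows for each input value.

    predict_fn is assumed to be any callable that accepts an array of
    shape (n, 1) and returns n predictions, for example model.predict.
    """
    batch = np.asarray(inputs, dtype=float).reshape(-1, 1)
    predicted = np.asarray(predict_fn(batch)).reshape(-1)
    return [(int(x), float(p), int(x) * factor)
            for x, p in zip(inputs, predicted)]

# Stand-in for a trained model (assumption): an exact "multiply by 3" function.
def exact_triple(batch):
    return batch * 3

for x, predicted, expected in compare_predictions(exact_triple, [64, 56, 345, 456]):
    print(x, '->', int(predicted), '<--- Predicted number')
    print(expected, ' <-- Correct answer \n')
```

With the trained network, the call would presumably become `compare_predictions(model.predict, [64, 56, 345, 456])`. Inputs far outside the training range of 1 to 24 are where the predicted and correct values are most likely to diverge.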