Skip to content

Instantly share code, notes, and snippets.

@ikbendewilliam
Created February 17, 2021 10:46
Show Gist options
  • Save ikbendewilliam/3737982823f63b1ed58abbb7c02dfe96 to your computer and use it in GitHub Desktop.
Save ikbendewilliam/3737982823f63b1ed58abbb7c02dfe96 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import pickle
import numpy as np
class RLModel:
    """Thin wrapper around a small Keras feed-forward network.

    A freshly built network maps a flattened input of ``4 * 4 * 3`` (= 48)
    values through one ReLU hidden layer to 4 linear outputs. It is trained
    with MSE loss and the Adam optimizer.

    NOTE(review): the 48 inputs look like a 4x4 board with 3 values per cell,
    and the 4 outputs look like one value per action. Confirm this against the
    calling/training code.
    """

    def __init__(self, version=None, learning_rate=1e-3, nodes=16):
        """Load a previously saved model, or build and compile a fresh one.

        Args:
            version: When not None, load the model from ``f"{version}.h5"``
                via :meth:`retrieveVariables` instead of building a new one.
            learning_rate: Adam learning rate used when building a fresh
                model. It is stored on the instance in both cases, but this
                class does not apply it to a loaded model.
            nodes: Width of the hidden ``Dense`` layer of a freshly built
                model. Ignored when loading.
        """
        # Set unconditionally so ``self.learning_rate`` exists on loaded
        # models as well as freshly built ones.
        self.learning_rate = learning_rate
        if version is not None:
            self.retrieveVariables(version)
        else:
            # ``nodes`` is now a real parameter. Previously it was an
            # undefined name here, and construction failed with NameError.
            self.create_model(nodes)

    def create_model(self, nodes=16):
        """Build and compile a new network, storing it as ``self.model``.

        Args:
            nodes: Number of units in the ReLU hidden layer (default 16).

        The optimizer is Adam with ``self.learning_rate``, which must already
        be set on the instance.
        """
        self.model = tf.keras.Sequential([
            tf.keras.layers.Input((4 * 4 * 3,)),
            tf.keras.layers.Dense(nodes, activation='relu'),
            tf.keras.layers.Dense(4),
        ])
        self.model.compile(
            loss='mse',
            optimizer=tf.optimizers.Adam(self.learning_rate),
        )

    def predict(self, board):
        """Run the model on a single input.

        ``board`` is converted to a NumPy array and flattened into a batch of
        one row. Any nested sequence whose total element count equals the
        model's input width is accepted (48 for a freshly built model).

        Returns:
            Whatever ``self.model.predict`` returns for that one-row batch.
            For a freshly built model this has 4 columns.
        """
        return self.model.predict(np.array(board).reshape(1, -1))

    def retrieveVariables(self, version):
        """Replace ``self.model`` with the model loaded from ``f"{version}.h5"``."""
        self.model = tf.keras.models.load_model(f'{version}.h5')

    def saveVariables(self, path='newModel.h5'):
        """Save ``self.model`` to ``path`` (default ``'newModel.h5'``).

        To reload a model saved at the default path, pass
        ``version='newModel'`` to the constructor, because
        :meth:`retrieveVariables` appends the ``.h5`` suffix itself.
        """
        self.model.save(path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment