import tensorflow as tf | |
import pickle | |
import numpy as np | |
class RLModel:
    """Thin wrapper around a small fully connected Keras regression model.

    The network maps a flat vector of 4 * 4 * 3 = 48 features to 4 linear
    outputs and is trained with mean-squared-error loss.  Judging by the class
    name it serves as the value/Q-function of a reinforcement-learning agent
    with one output per action -- NOTE(review): confirm against callers.

    Models are persisted as HDF5 files named ``"<version>.h5"``.
    """

    # File stem saveVariables() uses when no version is given; equal to the
    # previously hard-coded "newModel.h5" so existing calls are unchanged.
    _DEFAULT_VERSION = 'newModel'

    def __init__(self, version=None, learning_rate=1e-3, nodes=16):
        """Load a saved model or build a fresh one.

        Args:
            version: File stem of a saved model to load (``"<version>.h5"``).
                When None, a new untrained model is built instead.
            learning_rate: Adam learning rate used when building a fresh
                model.  It is recorded on the instance either way, but is not
                applied to a loaded model's optimizer.
            nodes: Width of the hidden Dense layer of a freshly built model.
        """
        # Record the rate unconditionally so the attribute also exists on
        # loaded models (e.g. if create_model() is called on them later).
        self.learning_rate = learning_rate
        if version is not None:
            self.retrieveVariables(version)
        else:
            self.create_model(nodes)

    def create_model(self, nodes=16):
        """Build and compile a fresh model, storing it on ``self.model``.

        Args:
            nodes: Number of units in the single ReLU hidden layer
                (default 16, the previously hard-coded width).
        """
        self.model = tf.keras.Sequential([
            tf.keras.layers.Input((4 * 4 * 3,)),
            tf.keras.layers.Dense(nodes, activation='relu'),
            # No activation: outputs are unbounded regression targets for MSE.
            tf.keras.layers.Dense(4),
        ])
        self.model.compile(
            loss='mse',
            optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),
        )

    @staticmethod
    def _encode_board(board):
        """Return *board* flattened into a one-row float32 batch, shape (1, -1)."""
        return np.asarray(board, dtype=np.float32).reshape(1, -1)

    @staticmethod
    def _model_path(version):
        """Return the HDF5 file path used to store model *version*."""
        return '{}.h5'.format(version)

    def predict(self, board):
        """Return the model's outputs for a single board as a (1, 4) ndarray.

        Args:
            board: Array-like whose elements flatten to the model's 48 input
                features.
        """
        # Calling the model directly avoids Model.predict's per-call setup
        # overhead, which dominates when scoring one sample at a time.
        outputs = self.model(self._encode_board(board), training=False)
        return outputs.numpy()

    def retrieveVariables(self, version):
        """Replace ``self.model`` with the model saved as ``"<version>.h5"``."""
        self.model = tf.keras.models.load_model(self._model_path(version))

    def saveVariables(self, version=None):
        """Save ``self.model`` to ``"<version>.h5"``.

        Args:
            version: File stem to save under; defaults to ``"newModel"``.
                Passing the same value to ``RLModel(version=...)`` reloads
                the saved model.
        """
        if version is None:
            version = self._DEFAULT_VERSION
        self.model.save(self._model_path(version))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment