Dviejopomata/rl.py (created May 4, 2019)

import random
import gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from keras.models import load_model
# from scores.score_logger import ScoreLogger
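
# Hyperparameters for the DQN agent. GAMMA is the discount factor on future
# rewards, LEARNING_RATE feeds the Adam optimizer, and the EXPLORATION_* constants
# define an epsilon-greedy schedule: the exploration rate starts at EXPLORATION_MAX
# and is multiplied by EXPLORATION_DECAY after every replay step, never dropping
# below EXPLORATION_MIN.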
ENV_NAME = "CartPole-v1"
GAMMA = 0.95
LEARNING_RATE = 0.001
MEMORY_SIZE = 1000000
BATCH_SIZE = 20
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995

class DQNSolver:
    def __init__(self, observation_space, action_space):
        self.exploration_rate = EXPLORATION_MAX
        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)
        # Small fully connected network mapping a state to one Q-value per action.
        self.model = Sequential()
        self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
        self.model.add(Dense(24, activation="relu"))
        self.model.add(Dense(self.action_space, activation="linear"))
        self.model.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # Epsilon-greedy action selection: explore with probability
        # exploration_rate, otherwise pick the action with the highest Q-value.
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        q_values = self.model.predict(state)
        return np.argmax(q_values[0])

    def load_model(self):
        self.model = load_model("./model")

    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            # Q-learning target: r for terminal transitions, otherwise
            # r + GAMMA * max_a' Q(s', a').
            q_update = reward
            if not terminal:
                q_update = reward + GAMMA * np.amax(self.model.predict(state_next)[0])
            q_values = self.model.predict(state)
            q_values[0][action] = q_update
            self.model.fit(state, q_values, verbose=0)
        # Decay exploration after each replay, clamped to the minimum.
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)
        self.model.save("./model", True, True)

LOAD_MODEL = True  # expects a model previously saved to "./model"

def cartpole():
    env = gym.make(ENV_NAME)
    # score_logger = ScoreLogger(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space)
    if LOAD_MODEL:
        dqn_solver.load_model()
    run = 0
    while True:
        run += 1
        state = env.reset()
        state = np.reshape(state, [1, observation_space])
        step = 0
        while True:
            step += 1
            # env.render()
            action = dqn_solver.act(state)
            state_next, reward, terminal, info = env.step(action)
            # Penalise the transition that ends the episode.
            reward = reward if not terminal else -reward
            state_next = np.reshape(state_next, [1, observation_space])
            dqn_solver.remember(state, action, reward, state_next, terminal)
            state = state_next
            if terminal:
                print(
                    "Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                # score_logger.add_score(step, run)
                break
            dqn_solver.experience_replay()
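
# A minimal evaluation sketch (not part of the original gist): it assumes a model
# has already been saved to "./model" by a previous training run, and runs the
# policy greedily by forcing the exploration rate to zero. The helper name
# `evaluate` is illustrative.
def evaluate(episodes=10):
    env = gym.make(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    solver = DQNSolver(observation_space, env.action_space.n)
    solver.load_model()
    solver.exploration_rate = 0.0  # always pick argmax Q, no random actions
    for episode in range(episodes):
        state = np.reshape(env.reset(), [1, observation_space])
        step = 0
        while True:
            step += 1
            action = solver.act(state)
            state, reward, terminal, info = env.step(action)
            state = np.reshape(state, [1, observation_space])
            if terminal:
                print("Evaluation run: " + str(episode + 1) + ", score: " + str(step))
                break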

if __name__ == "__main__":
    cartpole()
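
# To train from scratch rather than resume a saved network, set LOAD_MODEL = False
# before running; experience_replay() re-saves the model to "./model" on every call.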