Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic
Created July 7, 2019 11:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NMZivkovic/286e65cab6e66bd93884732a82966ea4 to your computer and use it in GitHub Desktop.
Save NMZivkovic/286e65cab6e66bd93884732a82966ea4 to your computer and use it in GitHub Desktop.
class Agent:
    """Deep Q-learning agent with experience replay and a target network.

    Two networks with identical architecture are kept: ``q_network`` is the
    one trained on replayed transitions, while ``target_network`` supplies the
    bootstrap estimate of the next state's value and is re-synchronised from
    ``q_network`` by ``alighn_target_model``.

    Args:
        enviroment: Environment exposing ``observation_space.n``,
            ``action_space.n`` and ``action_space.sample()`` (presumably a
            Gym-style environment with discrete spaces — verify against
            caller).
        optimizer: Optimizer passed to ``model.compile`` for both networks.
        gamma: Discount factor applied to the target network's estimate.
        epsilon: Probability that ``act`` returns a random action.
        replay_size: Maximum number of transitions kept in the replay
            buffer; once full, the oldest transition is discarded.
    """

    def __init__(self, enviroment, optimizer, *, gamma=0.6, epsilon=0.1,
                 replay_size=2000):
        # Keep the environment on the instance so ``act`` can sample random
        # actions without depending on a module-level ``enviroment`` global.
        self._enviroment = enviroment
        self._state_size = enviroment.observation_space.n
        self._action_size = enviroment.action_space.n
        self._optimizer = optimizer
        # deque(maxlen=...) silently drops the oldest entry when full.
        self.expirience_replay = deque(maxlen=replay_size)

        # Discount and exploration rate
        self.gamma = gamma
        self.epsilon = epsilon

        # Build both networks and start them from identical weights.
        self.q_network = self._build_compile_model()
        self.target_network = self._build_compile_model()
        self.alighn_target_model()

    def store(self, state, action, reward, next_state, terminated):
        """Append one transition to the replay buffer.

        The tuple order ``(state, action, reward, next_state, terminated)``
        is the order ``retrain`` unpacks.
        """
        self.expirience_replay.append((state, action, reward, next_state, terminated))

    def _build_compile_model(self):
        """Build and compile the Q-value network.

        Architecture: an embedding of the (discrete) state index into 10
        dimensions, flattened, followed by two 50-unit ReLU layers and a
        linear output layer with one unit per action. Compiled with MSE loss
        and the optimizer given to the constructor.

        Returns:
            The compiled Keras ``Sequential`` model.
        """
        model = Sequential()
        # input_length=1: each input is a single state index.
        model.add(Embedding(self._state_size, 10, input_length=1))
        model.add(Reshape((10,)))
        model.add(Dense(50, activation='relu'))
        model.add(Dense(50, activation='relu'))
        model.add(Dense(self._action_size, activation='linear'))
        model.compile(loss='mse', optimizer=self._optimizer)
        return model

    def alighn_target_model(self):
        """Copy the current ``q_network`` weights into ``target_network``."""
        self.target_network.set_weights(self.q_network.get_weights())

    def act(self, state):
        """Choose an action for ``state`` with an epsilon-greedy policy.

        With probability ``epsilon`` a random action is sampled from the
        environment's action space; otherwise the action with the highest
        predicted Q-value is returned.

        Args:
            state: Network input for a single state (assumed to match the
                ``Embedding(input_length=1)`` input, e.g. shape ``(1, 1)`` —
                TODO confirm against caller).

        Returns:
            The chosen action index.
        """
        if np.random.rand() <= self.epsilon:
            return self._enviroment.action_space.sample()
        q_values = self.q_network.predict(state)
        return np.argmax(q_values[0])

    def retrain(self, batch_size):
        """Fit ``q_network`` on a random minibatch of stored transitions.

        Each sampled transition is fitted individually for one epoch. Only
        the entry for the taken action is replaced in the target:

        * terminal transition: ``reward``;
        * otherwise: ``reward + gamma * max(target_network(next_state))``.

        The remaining entries keep the network's own current predictions, so
        they contribute no error under the MSE loss.

        Args:
            batch_size: Number of transitions to sample.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        minibatch = random.sample(self.expirience_replay, batch_size)
        for state, action, reward, next_state, terminated in minibatch:
            target = self.q_network.predict(state)
            if terminated:
                target[0][action] = reward
            else:
                t = self.target_network.predict(next_state)
                target[0][action] = reward + self.gamma * np.amax(t)
            self.q_network.fit(state, target, epochs=1, verbose=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment