Reinforcement learning - Navigation companion code: a Double DQN agent with uniform experience replay, written in PyTorch.
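# ---------------------------------------------------------------------------
# Note: the original gist contains only the Agent and ReplayBuffer classes; it
# assumes a QNetwork class (typically defined in a separate model.py), a torch
# `device`, and several hyperparameter constants that are not shown. Everything
# in this block is an assumed, minimal stand-in so the file runs end to end:
# the QNetwork below is a plain two-hidden-layer MLP, and the hyperparameter
# values are common defaults, not values taken from the original.
# ---------------------------------------------------------------------------
import random
from collections import deque, namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

BUFFER_SIZE = int(1e5)  # replay buffer size (assumed)
BATCH_SIZE = 64         # minibatch size (assumed)
GAMMA = 0.99            # discount factor (assumed)
TAU = 1e-3              # soft-update interpolation factor (assumed)
LR = 5e-4               # learning rate (assumed)
UPDATE_EVERY = 4        # learn every UPDATE_EVERY environment steps (assumed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class QNetwork(nn.Module):
    """Minimal value-network sketch: maps a state to one Q-value per action."""

    def __init__(self, state_size, action_size, seed, fc1_units=64, fc2_units=64):
        super().__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, action_size)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)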
class Agent():
    """Interacts with and learns from the environment using a Double DQN."""

    def __init__(self, state_size, action_size, seed):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Networks: the local network is trained; the target network
        # provides slowly-moving targets and is soft-updated toward it
        self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # If enough samples are available in memory, get a random subset and learn
            if len(self.memory) > BATCH_SIZE:
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA)

    def act(self, state, eps=0.):
        """Return an epsilon-greedy action for the given state."""
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma):
        """Update value parameters from a sampled batch of experience tuples."""
        states, actions, rewards, next_states, dones = experiences

        # Plain DQN would take the max predicted Q-value for the next states
        # straight from the target (frozen) network:
        # Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)

        # Double DQN: choose the next actions with the local network,
        # then evaluate those actions with the target network
        next_actions = self.qnetwork_local(next_states).detach().argmax(1).unsqueeze(1)
        Q_targets_next = self.qnetwork_target(next_states).detach().gather(1, next_actions)

        # Compute Q targets for the current states (no bootstrapping on terminal states)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from the local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute the loss and minimize it
        loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Soft-update target parameters: θ_target ← τ·θ_local + (1−τ)·θ_target."""
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)


class ReplayBuffer:
    """Fixed-size buffer that stores experience tuples and samples them uniformly."""

    def __init__(self, action_size, buffer_size, batch_size, seed):
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.seed = random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self):
        """Randomly sample a batch of experiences and stack them into tensors."""
        experiences = random.sample(self.memory, k=self.batch_size)

        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return the current number of stored experiences."""
        return len(self.memory)
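

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist): a minimal epsilon-greedy
# training loop showing how the Agent is typically driven. `env` stands in for
# a hypothetical Gym-style environment with reset()/step(); the episode count
# and epsilon schedule below are assumed values, not taken from the original.
# ---------------------------------------------------------------------------
def dqn_train(env, agent, n_episodes=2000, max_t=1000,
              eps_start=1.0, eps_end=0.01, eps_decay=0.995):
    scores = deque(maxlen=100)  # rolling window of the last 100 episode scores
    eps = eps_start
    for i_episode in range(1, n_episodes + 1):
        state = env.reset()
        score = 0
        for _ in range(max_t):
            action = agent.act(state, eps)                       # epsilon-greedy action
            next_state, reward, done, _ = env.step(action)       # advance the environment
            agent.step(state, action, reward, next_state, done)  # store and (periodically) learn
            state = next_state
            score += reward
            if done:
                break
        scores.append(score)
        eps = max(eps_end, eps_decay * eps)  # decay exploration over time
        if i_episode % 100 == 0:
            print(f"Episode {i_episode}\tAverage score: {np.mean(scores):.2f}")
    return scores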