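# Excerpt from a DAgger-style training loop for a hierarchical vision-and-language
# navigation (VLN) agent: a high-level policy is supervised with a cross-entropy loss
# over discrete oracle actions, and a low-level policy regresses continuous corrected
# actions plus a stop signal, each on its own device (self.device / self.device2) with
# its own optimizer. These are methods of a trainer class; project-specific helpers
# (AuxLosses, repackage_hidden, split_batch_tbptt) and the config object are assumed
# to be available in the enclosing module.
import numpy as np
import torch
import torch.nn as nn
import tqdm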
def train_epoch(self, diter, length, batch_size, epoch, writer, train_steps):
    """Run one DAgger training epoch over the dataset iterator `diter`."""
    loss, action_loss, aux_loss = 0, 0, 0
    step_id = 0

    # Put the policy (or both sub-policies) in training mode.
    if self.config.DAGGER.INTER_MODULE_ATTN:
        self.actor_critic.train()
    else:
        self.high_level.train()
        self.low_level.train()

    for batch in tqdm.tqdm(diter, total=length // batch_size, leave=False):
        (
            observations_batch,
            prev_actions_batch,
            not_done_masks,
            corrected_actions_batch,
            oracle_stop_batch,
        ) = batch

        # Fresh recurrent hidden states for each trajectory batch. The high-level and
        # low-level encoders live on separate devices.
        if self.config.DAGGER.INTER_MODULE_ATTN:
            high_recurrent_hidden_states = torch.zeros(
                self.actor_critic.num_recurrent_layers,
                self.config.DAGGER.BATCH_SIZE,
                self.config.MODEL.STATE_ENCODER.hidden_size,
                device=self.device,
            )
            low_recurrent_hidden_states = torch.zeros(
                self.actor_critic.num_recurrent_layers,
                self.config.DAGGER.BATCH_SIZE,
                self.config.MODEL.STATE_ENCODER.hidden_size,
                device=self.device2,
            )
        else:
            high_recurrent_hidden_states = torch.zeros(
                self.high_level.state_encoder.num_recurrent_layers,
                self.config.DAGGER.BATCH_SIZE,
                self.config.MODEL.STATE_ENCODER.hidden_size,
                device=self.device,
            )
            low_recurrent_hidden_states = torch.zeros(
                self.low_level.state_encoder.num_recurrent_layers,
                self.config.DAGGER.BATCH_SIZE,
                self.config.MODEL.STATE_ENCODER.hidden_size,
                device=self.device2,
            )

        detached_state_low = None

        # Split the full trajectory batch into shorter segments for truncated
        # backpropagation through time (TBPTT).
        batch_split = split_batch_tbptt(
            observations_batch,
            prev_actions_batch,
            not_done_masks,
            corrected_actions_batch,
            oracle_stop_batch,
            self.config.DAGGER.tbptt_steps,
            self.config.DAGGER.split_dim,
        )
        del observations_batch, prev_actions_batch, not_done_masks, corrected_actions_batch, batch

        for split in batch_split:
            (
                observations_batch,
                prev_actions_batch,
                not_done_masks,
                corrected_actions_batch,
                oracle_stop_batch,
            ) = split
            observations_batch = {
                k: v.to(device=self.device, non_blocking=True)
                for k, v in observations_batch.items()
            }

            loss, high_recurrent_hidden_states, low_recurrent_hidden_states, detached_state_low = self._update_agent(
                observations_batch,
                prev_actions_batch.to(device=self.device, non_blocking=True),
                not_done_masks.to(device=self.device, non_blocking=True),
                corrected_actions_batch.to(device=self.device, non_blocking=True),
                oracle_stop_batch.to(device=self.device, non_blocking=True),
                high_recurrent_hidden_states,
                low_recurrent_hidden_states,
                detached_state_low,
            )

            writer.add_scalar("Train High Level Action Loss", loss[0], train_steps)
            writer.add_scalar("Train Low Level Action Loss", loss[1], train_steps)
            writer.add_scalar("Train Low Level Stop Loss", loss[2], train_steps)
            writer.add_scalar("Train Low Level Total Loss", loss[1] + loss[2], train_steps)
            train_steps += 1

    self.save_checkpoint(f"ckpt.{self.config.DAGGER.EPOCHS + epoch}.pth")
    return train_steps
def val_epoch(self, diter, length, batch_size, epoch, writer, val_steps):
    """Run one validation epoch; returns the updated global validation step count."""
    loss, aux_loss = 0, 0
    step_id = 0
    val_high_losses = []
    val_low_losses = []

    if self.config.DAGGER.INTER_MODULE_ATTN:
        self.actor_critic.eval()
    else:
        self.high_level.eval()
        self.low_level.eval()

    correct_labels = 0
    total_correct = 0

    with torch.no_grad():
        for batch in tqdm.tqdm(diter, total=length // batch_size, leave=False):
            (
                observations_batch,
                prev_actions_batch,
                not_done_masks,
                corrected_actions_batch,
                oracle_stop_batch,
            ) = batch

            if self.config.DAGGER.INTER_MODULE_ATTN:
                high_recurrent_hidden_states = torch.zeros(
                    self.actor_critic.num_recurrent_layers,
                    self.config.DAGGER.BATCH_SIZE,
                    self.config.MODEL.STATE_ENCODER.hidden_size,
                    device=self.device,
                )
                low_recurrent_hidden_states = torch.zeros(
                    self.actor_critic.num_recurrent_layers,
                    self.config.DAGGER.BATCH_SIZE,
                    self.config.MODEL.STATE_ENCODER.hidden_size,
                    device=self.device2,
                )
            else:
                high_recurrent_hidden_states = torch.zeros(
                    self.high_level.state_encoder.num_recurrent_layers,
                    self.config.DAGGER.BATCH_SIZE,
                    self.config.MODEL.STATE_ENCODER.hidden_size,
                    device=self.device,
                )
                low_recurrent_hidden_states = torch.zeros(
                    self.low_level.state_encoder.num_recurrent_layers,
                    self.config.DAGGER.BATCH_SIZE,
                    self.config.MODEL.STATE_ENCODER.hidden_size,
                    device=self.device2,
                )

            detached_state_low = None
            batch_split = split_batch_tbptt(
                observations_batch,
                prev_actions_batch,
                not_done_masks,
                corrected_actions_batch,
                oracle_stop_batch,
                self.config.DAGGER.tbptt_steps,
                self.config.DAGGER.split_dim,
            )
            del observations_batch, prev_actions_batch, not_done_masks, corrected_actions_batch, batch

            for split in batch_split:
                (
                    observations_batch,
                    prev_actions_batch,
                    not_done_masks,
                    corrected_actions_batch,
                    oracle_stop_batch,
                ) = split
                observations_batch = {
                    k: v.to(device=self.device, non_blocking=True)
                    for k, v in observations_batch.items()
                }

                loss, high_recurrent_hidden_states, low_recurrent_hidden_states, detached_state_low, correct, total = self._update_agent_val(
                    observations_batch,
                    prev_actions_batch.to(device=self.device, non_blocking=True),
                    not_done_masks.to(device=self.device, non_blocking=True),
                    corrected_actions_batch.to(device=self.device, non_blocking=True),
                    oracle_stop_batch.to(device=self.device, non_blocking=True),
                    high_recurrent_hidden_states,
                    low_recurrent_hidden_states,
                    detached_state_low,
                )
                correct_labels += correct
                total_correct += total

                writer.add_scalar("Val High Level Action Loss", loss[0], val_steps)
                writer.add_scalar("Val Low Level Total Loss", loss[1] + loss[2], val_steps)
                val_steps += 1

                # loss = (high_level, low_level_action, low_level_stop, aux); keep the
                # high-level loss and the combined low-level loss for epoch averages.
                val_high_losses.append(loss[0])
                val_low_losses.append(loss[1] + loss[2])

    final_accuracy = 100 * correct_labels / total_correct
    writer.add_scalar("Val High Level Loss (epoch)", np.mean(val_high_losses), epoch)
    writer.add_scalar("Val Low Level Loss (epoch)", np.mean(val_low_losses), epoch)
    writer.add_scalar("Validation Accuracy", final_accuracy, epoch)
    return val_steps
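# Illustrative driver (not part of the original gist): a minimal sketch of how the
# epoch methods above might be threaded together, assuming `trainer` exposes the
# train_epoch / val_epoch methods defined here and the data iterators yield batches
# in the same layout. All argument names are hypothetical placeholders.
def run_training(trainer, train_iter, val_iter, train_len, val_len, log_dir="runs/hierarchical_vln"):
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir=log_dir)
    train_steps, val_steps = 0, 0
    batch_size = trainer.config.DAGGER.BATCH_SIZE
    for epoch in range(trainer.config.DAGGER.EPOCHS):
        # The epoch methods return the updated global step counters so that
        # TensorBoard scalars keep increasing across epochs.
        train_steps = trainer.train_epoch(train_iter, train_len, batch_size, epoch, writer, train_steps)
        val_steps = trainer.val_epoch(val_iter, val_len, batch_size, epoch, writer, val_steps)
    writer.close()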
def _update_agent(
    self,
    observations,
    prev_actions,
    not_done_masks,
    corrected_actions,
    oracle_stop,
    high_recurrent_hidden_states,
    low_recurrent_hidden_states,
    detached_state_low,
):
    """One optimization step on a TBPTT segment.

    The high-level policy is trained with a cross-entropy loss over the oracle
    discrete actions (on self.device); the low-level policy is trained with an MSE
    loss over the corrected continuous actions plus a BCE-with-logits stop loss
    (on self.device2). Each sub-policy has its own optimizer.
    """
    self.optimizer_high_level.zero_grad()
    self.optimizer_low_level.zero_grad()

    high_level_criterion = nn.CrossEntropyLoss(ignore_index=-1, reduction="mean")
    low_level_criterion = nn.MSELoss()
    low_level_stop_criterion = nn.BCEWithLogitsLoss()
    AuxLosses.clear()

    # Detach the hidden states from the previous segment's graph (TBPTT).
    high_recurrent_hidden_states = repackage_hidden(high_recurrent_hidden_states)
    low_recurrent_hidden_states = repackage_hidden(low_recurrent_hidden_states)

    if self.config.DAGGER.INTER_MODULE_ATTN:
        # A sensor value of 0 marks padded / invalid timesteps.
        high_level_action_mask = observations['vln_oracle_action_sensor'] == 0
        observations['vln_oracle_action_sensor'] = observations['vln_oracle_action_sensor'].to(dtype=torch.int64)
        discrete_actions = observations['vln_oracle_action_sensor']
        discrete_action_mask = discrete_actions == 0
        discrete_actions = (discrete_actions - 1).masked_fill_(discrete_action_mask, 4)

        batch = (
            observations,
            high_recurrent_hidden_states,
            low_recurrent_hidden_states,
            prev_actions,
            not_done_masks,
            discrete_actions.view(-1),
            detached_state_low,
        )
        (
            high_output,
            low_output,
            low_stop_output,
            high_recurrent_hidden_states,
            low_recurrent_hidden_states,
            detached_state_low,
        ) = self.actor_critic(batch)
        del batch

        # High-level discrete action loss (labels are shifted to start at 0).
        high_output = high_output.masked_fill_(high_level_action_mask, 0)
        high_level_loss = high_level_criterion(
            high_output, (observations['vln_oracle_action_sensor'] - 1).squeeze(1)
        )
        high_level_loss.backward()
        high_level_loss_data = high_level_loss.detach()
        del high_output
        self.optimizer_high_level.step()

        # Low-level losses are computed on the second device.
        oracle_stop = oracle_stop.to(self.device2)
        corrected_actions = corrected_actions.to(self.device2)

        action_mask = corrected_actions == 0
        low_output = low_output.masked_fill_(action_mask, 0)
        low_output = low_output.to(dtype=torch.float)
        corrected_actions = corrected_actions.to(dtype=torch.float)
        low_level_action_loss = low_level_criterion(low_output, corrected_actions)

        # Stop labels of -1 are invalid; compute the stop loss on valid entries only.
        mask = oracle_stop != -1
        oracle_stop = torch.masked_select(oracle_stop, mask)
        low_stop_output = torch.masked_select(low_stop_output, mask)
        low_level_stop_loss = low_level_stop_criterion(low_stop_output, oracle_stop)

        low_level_loss = low_level_action_loss + low_level_stop_loss
        low_level_loss.backward()
        self.optimizer_low_level.step()
    else:
        # High-level policy forward/backward on self.device.
        batch = (observations, high_recurrent_hidden_states, prev_actions, not_done_masks)
        output, high_recurrent_hidden_states = self.high_level(batch)
        del batch

        high_level_action_mask = observations['vln_oracle_action_sensor'] == 0
        output = output.masked_fill_(high_level_action_mask, 0)
        observations['vln_oracle_action_sensor'] = observations['vln_oracle_action_sensor'].squeeze(1).to(dtype=torch.int64)
        high_level_loss = high_level_criterion(output, observations['vln_oracle_action_sensor'] - 1)
        high_level_loss.backward()
        self.optimizer_high_level.step()
        high_level_loss_data = high_level_loss.detach()
        del output

        # Low-level policy forward/backward on self.device2.
        self.low_level.to(self.device2)
        observations = {
            k: v.to(device=self.device2, non_blocking=True)
            for k, v in observations.items()
        }

        discrete_actions = observations['vln_oracle_action_sensor']
        discrete_action_mask = discrete_actions == 0
        discrete_actions = (discrete_actions - 1).masked_fill_(discrete_action_mask, 4)
        del observations['vln_oracle_action_sensor']

        batch = (
            observations,
            low_recurrent_hidden_states,
            prev_actions.to(device=self.device2, non_blocking=True),
            not_done_masks.to(device=self.device2, non_blocking=True),
            discrete_actions.view(-1),
        )
        del observations, prev_actions, not_done_masks

        oracle_stop = oracle_stop.to(self.device2)
        output, stop_out, low_recurrent_hidden_states = self.low_level(batch)

        corrected_actions = corrected_actions.to(self.device2)
        action_mask = corrected_actions == 0
        output = output.masked_fill_(action_mask, 0)
        output = output.to(dtype=torch.float)
        corrected_actions = corrected_actions.to(dtype=torch.float)
        low_level_action_loss = low_level_criterion(output, corrected_actions)

        mask = oracle_stop != -1
        oracle_stop = torch.masked_select(oracle_stop, mask)
        stop_out = torch.masked_select(stop_out, mask)
        low_level_stop_loss = low_level_stop_criterion(stop_out, oracle_stop)

        low_level_loss = low_level_action_loss + low_level_stop_loss
        low_level_loss.backward()
        self.optimizer_low_level.step()

    aux_loss_data = 0
    loss = (
        high_level_loss_data.item(),
        low_level_action_loss.detach().item(),
        low_level_stop_loss.detach().item(),
        aux_loss_data,
    )
    return loss, high_recurrent_hidden_states, low_recurrent_hidden_states, detached_state_low
def _update_agent_val(
    self,
    observations,
    prev_actions,
    not_done_masks,
    corrected_actions,
    oracle_stop,
    high_recurrent_hidden_states,
    low_recurrent_hidden_states,
    detached_state_low,
):
    """Validation counterpart of _update_agent: computes the same losses without
    backpropagation or optimizer steps, plus a high-level action accuracy."""
    high_level_criterion = nn.CrossEntropyLoss(ignore_index=-1, reduction="mean")
    low_level_criterion = nn.MSELoss()
    low_level_stop_criterion = nn.BCEWithLogitsLoss()
    AuxLosses.clear()

    high_recurrent_hidden_states = repackage_hidden(high_recurrent_hidden_states)
    low_recurrent_hidden_states = repackage_hidden(low_recurrent_hidden_states)

    if self.config.DAGGER.INTER_MODULE_ATTN:
        high_level_action_mask = observations['vln_oracle_action_sensor'] == 0
        observations['vln_oracle_action_sensor'] = observations['vln_oracle_action_sensor'].to(dtype=torch.int64)
        discrete_actions = observations['vln_oracle_action_sensor']
        discrete_action_mask = discrete_actions == 0
        discrete_actions = (discrete_actions - 1).masked_fill_(discrete_action_mask, 4)

        batch = (
            observations,
            high_recurrent_hidden_states,
            low_recurrent_hidden_states,
            prev_actions,
            not_done_masks,
            discrete_actions.view(-1),
            detached_state_low,
        )
        (
            high_output,
            low_output,
            low_stop_output,
            high_recurrent_hidden_states,
            low_recurrent_hidden_states,
            detached_state_low,
        ) = self.actor_critic(batch)
        del batch

        high_output = high_output.masked_fill_(high_level_action_mask, 0)
        high_level_loss = high_level_criterion(
            high_output, (observations['vln_oracle_action_sensor'] - 1).squeeze(1)
        )

        # The original snippet only computed accuracy in the non-attention branch;
        # mirror it here so both code paths return defined values.
        predicted = torch.argmax(high_output, dim=1)
        corrected_mask = (~high_level_action_mask).view(-1)
        correct = torch.masked_select(
            (observations['vln_oracle_action_sensor'] - 1).view(-1), corrected_mask
        )
        predicted = torch.masked_select(predicted, corrected_mask)
        accuracy = (predicted == correct).sum().item()
        total = predicted.size(0)

        oracle_stop = oracle_stop.to(self.device2)
        corrected_actions = corrected_actions.to(self.device2)

        action_mask = corrected_actions == 0
        low_output = low_output.masked_fill_(action_mask, 0)
        low_output = low_output.to(dtype=torch.float)
        corrected_actions = corrected_actions.to(dtype=torch.float)
        low_level_action_loss = low_level_criterion(low_output, corrected_actions)

        mask = oracle_stop != -1
        oracle_stop = torch.masked_select(oracle_stop, mask)
        low_stop_output = torch.masked_select(low_stop_output, mask)
        low_level_stop_loss = low_level_stop_criterion(low_stop_output, oracle_stop)
    else:
        batch = (observations, high_recurrent_hidden_states, prev_actions, not_done_masks)
        output, high_recurrent_hidden_states = self.high_level(batch)
        del batch

        high_level_action_mask = observations['vln_oracle_action_sensor'] == 0
        output = output.masked_fill_(high_level_action_mask, 0)
        observations['vln_oracle_action_sensor'] = observations['vln_oracle_action_sensor'].squeeze(1).to(dtype=torch.int64)
        high_level_loss = high_level_criterion(output, observations['vln_oracle_action_sensor'] - 1)

        # High-level accuracy over valid (non-padded) timesteps; the padding mask is
        # flattened so it lines up with the flattened labels and predictions.
        predicted = torch.argmax(output, dim=1)
        corrected_mask = (~high_level_action_mask).view(-1)
        correct = torch.masked_select(observations['vln_oracle_action_sensor'] - 1, corrected_mask)
        predicted = torch.masked_select(predicted, corrected_mask)
        accuracy = (predicted == correct).sum().item()
        total = predicted.size(0)
        del output

        self.low_level.to(self.device2)
        observations = {
            k: v.to(device=self.device2, non_blocking=True)
            for k, v in observations.items()
        }
        batch = (
            observations,
            low_recurrent_hidden_states,
            prev_actions.to(device=self.device2, non_blocking=True),
            not_done_masks.to(device=self.device2, non_blocking=True),
            observations['vln_oracle_action_sensor'] - 1,
        )
        del observations, prev_actions, not_done_masks

        oracle_stop = oracle_stop.to(self.device2)
        output, stop_out, low_recurrent_hidden_states = self.low_level(batch)

        corrected_actions = corrected_actions.to(self.device2)
        action_mask = corrected_actions == 0
        output = output.masked_fill_(action_mask, 0)
        output = output.to(dtype=torch.float)
        corrected_actions = corrected_actions.to(dtype=torch.float)
        low_level_action_loss = low_level_criterion(output, corrected_actions)

        mask = oracle_stop != -1
        oracle_stop = torch.masked_select(oracle_stop, mask)
        stop_out = torch.masked_select(stop_out, mask)
        low_level_stop_loss = low_level_stop_criterion(stop_out, oracle_stop)

    aux_loss_data = 0
    loss = (
        high_level_loss.item(),
        low_level_action_loss.item(),
        low_level_stop_loss.item(),
        aux_loss_data,
    )
    return loss, high_recurrent_hidden_states, low_recurrent_hidden_states, detached_state_low, accuracy, total
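# The helpers repackage_hidden() and split_batch_tbptt() are referenced above but not
# included in this gist. Below is a minimal sketch of what they are assumed to do, based
# only on how they are called here; the actual implementations in the author's repository
# may differ. repackage_hidden detaches recurrent states from the previous segment's
# graph, and split_batch_tbptt chops a full-trajectory batch into fixed-length segments
# along the time dimension for truncated backpropagation through time.
def repackage_hidden(h):
    # Recursively detach tensors (or tuples/lists of tensors) from their history.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)


def split_batch_tbptt(observations, prev_actions, not_done_masks,
                      corrected_actions, oracle_stop, tbptt_steps, split_dim):
    # Split every tensor into chunks of `tbptt_steps` along `split_dim` and yield one
    # tuple per segment, mirroring the unpacking used in the training/validation loops.
    obs_splits = {k: torch.split(v, tbptt_steps, dim=split_dim) for k, v in observations.items()}
    prev_splits = torch.split(prev_actions, tbptt_steps, dim=split_dim)
    mask_splits = torch.split(not_done_masks, tbptt_steps, dim=split_dim)
    corr_splits = torch.split(corrected_actions, tbptt_steps, dim=split_dim)
    stop_splits = torch.split(oracle_stop, tbptt_steps, dim=split_dim)
    for i in range(len(prev_splits)):
        yield (
            {k: v[i] for k, v in obs_splits.items()},
            prev_splits[i],
            mask_splits[i],
            corr_splits[i],
            stop_splits[i],
        )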