Created
November 9, 2020 20:48
-
-
Save zubair-irshad/21b0be90645224fe0703eebd2326092a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_epoch(self, diter, length, batch_size, epoch, writer, train_steps):
    """Run one training epoch with truncated BPTT, then save a checkpoint.

    Every batch from ``diter`` starts from zeroed recurrent states, is cut
    into chunks by ``split_batch_tbptt`` and each chunk is passed to
    ``self._update_agent``; the recurrent states returned for one chunk are
    fed into the next chunk of the same batch.

    Args:
        diter: Iterable yielding 5-tuples ``(observations, prev_actions,
            not_done_masks, corrected_actions, oracle_stop)`` where
            ``observations`` is a dict of tensors.
        length: Used together with ``batch_size`` only to size the progress
            bar (``length // batch_size``).
        batch_size: See ``length``.
        epoch: Added to ``config.DAGGER.EPOCHS`` to name the checkpoint file.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        train_steps: Global step counter used as the scalar x-axis.

    Returns:
        ``train_steps`` advanced by one per processed chunk.
    """
    cfg = self.config

    # Put the model(s) in train mode and pick the recurrent depth for each
    # hidden state; both values are loop-invariant.
    if cfg.DAGGER.INTER_MODULE_ATTN:
        self.actor_critic.train()
        high_layers = self.actor_critic.num_recurrent_layers
        low_layers = self.actor_critic.num_recurrent_layers
    else:
        self.high_level.train()
        self.low_level.train()
        high_layers = self.high_level.state_encoder.num_recurrent_layers
        low_layers = self.low_level.state_encoder.num_recurrent_layers

    # NOTE(review): the hidden states are sized from config.DAGGER.BATCH_SIZE,
    # not from the batch actually yielded by `diter` -- confirm the loader
    # never produces a smaller final batch.
    hidden_batch = cfg.DAGGER.BATCH_SIZE
    hidden_size = cfg.MODEL.STATE_ENCODER.hidden_size

    for batch in tqdm.tqdm(diter, total=length // batch_size, leave=False):
        (
            observations_batch,
            prev_actions_batch,
            not_done_masks,
            corrected_actions_batch,
            oracle_stop_batch,
        ) = batch

        # Fresh recurrent state per batch; the low-level state lives on the
        # second device.
        high_recurrent_hidden_states = torch.zeros(
            high_layers, hidden_batch, hidden_size, device=self.device
        )
        low_recurrent_hidden_states = torch.zeros(
            low_layers, hidden_batch, hidden_size, device=self.device2
        )
        detached_state_low = None

        batch_split = split_batch_tbptt(
            observations_batch,
            prev_actions_batch,
            not_done_masks,
            corrected_actions_batch,
            oracle_stop_batch,
            cfg.DAGGER.tbptt_steps,
            cfg.DAGGER.split_dim,
        )
        # Drop the references to the unsplit batch; only the chunks are used
        # from here on.
        del (
            observations_batch,
            prev_actions_batch,
            not_done_masks,
            corrected_actions_batch,
            oracle_stop_batch,
            batch,
        )

        for split in batch_split:
            (
                observations_batch,
                prev_actions_batch,
                not_done_masks,
                corrected_actions_batch,
                oracle_stop_batch,
            ) = split

            observations_batch = {
                k: v.to(device=self.device, non_blocking=True)
                for k, v in observations_batch.items()
            }

            # The original wrapped this call in a `try:` with no handler,
            # which is not valid Python; errors now simply propagate.
            (
                loss,
                high_recurrent_hidden_states,
                low_recurrent_hidden_states,
                detached_state_low,
            ) = self._update_agent(
                observations_batch,
                prev_actions_batch.to(device=self.device, non_blocking=True),
                not_done_masks.to(device=self.device, non_blocking=True),
                corrected_actions_batch.to(
                    device=self.device, non_blocking=True
                ),
                oracle_stop_batch.to(device=self.device, non_blocking=True),
                high_recurrent_hidden_states,
                low_recurrent_hidden_states,
                detached_state_low,
            )

            # loss = (high-level, low-level action, low-level stop, aux)
            writer.add_scalar(
                "Train High Level Action Loss", loss[0], train_steps
            )
            writer.add_scalar(
                "Train Low Level Action Loss", loss[1], train_steps
            )
            writer.add_scalar("Train Low Level Stop Loss", loss[2], train_steps)
            writer.add_scalar(
                "Train Low_level Total Loss", loss[1] + loss[2], train_steps
            )
            train_steps += 1

    self.save_checkpoint(f"ckpt.{cfg.DAGGER.EPOCHS + epoch}.pth")
    return train_steps
def val_epoch(self, diter, length, batch_size, epoch, writer, val_steps):
    """Run one validation epoch under ``torch.no_grad`` and log the results.

    Mirrors the training loop: every batch starts from zeroed recurrent
    states, is cut into chunks by ``split_batch_tbptt`` and each chunk is
    evaluated with ``self._update_agent_val``. Per-chunk losses are logged
    against ``val_steps``; epoch means and accuracy are logged against
    ``epoch``.

    Args:
        diter: Iterable yielding 5-tuples ``(observations, prev_actions,
            not_done_masks, corrected_actions, oracle_stop)`` where
            ``observations`` is a dict of tensors.
        length: Used together with ``batch_size`` only to size the progress
            bar (``length // batch_size``).
        batch_size: See ``length``.
        epoch: x-axis value for the per-epoch scalars.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        val_steps: Global validation step counter.

    Returns:
        ``val_steps`` advanced by one per processed chunk.
    """
    cfg = self.config

    # Put the model(s) in eval mode and pick the recurrent depth for each
    # hidden state; both values are loop-invariant.
    if cfg.DAGGER.INTER_MODULE_ATTN:
        self.actor_critic.eval()
        high_layers = self.actor_critic.num_recurrent_layers
        low_layers = self.actor_critic.num_recurrent_layers
    else:
        self.high_level.eval()
        self.low_level.eval()
        high_layers = self.high_level.state_encoder.num_recurrent_layers
        low_layers = self.low_level.state_encoder.num_recurrent_layers

    # NOTE(review): the hidden states are sized from config.DAGGER.BATCH_SIZE,
    # not from the batch actually yielded by `diter` -- confirm the loader
    # never produces a smaller final batch.
    hidden_batch = cfg.DAGGER.BATCH_SIZE
    hidden_size = cfg.MODEL.STATE_ENCODER.hidden_size

    val_high_losses = []
    val_low_losses = []
    correct_labels = 0
    total_labels = 0

    with torch.no_grad():
        for batch in tqdm.tqdm(
            diter, total=length // batch_size, leave=False
        ):
            (
                observations_batch,
                prev_actions_batch,
                not_done_masks,
                corrected_actions_batch,
                oracle_stop_batch,
            ) = batch

            high_recurrent_hidden_states = torch.zeros(
                high_layers, hidden_batch, hidden_size, device=self.device
            )
            low_recurrent_hidden_states = torch.zeros(
                low_layers, hidden_batch, hidden_size, device=self.device2
            )
            detached_state_low = None

            batch_split = split_batch_tbptt(
                observations_batch,
                prev_actions_batch,
                not_done_masks,
                corrected_actions_batch,
                oracle_stop_batch,
                cfg.DAGGER.tbptt_steps,
                cfg.DAGGER.split_dim,
            )
            # Drop the references to the unsplit batch; only the chunks are
            # used from here on.
            del (
                observations_batch,
                prev_actions_batch,
                not_done_masks,
                corrected_actions_batch,
                oracle_stop_batch,
                batch,
            )

            for split in batch_split:
                (
                    observations_batch,
                    prev_actions_batch,
                    not_done_masks,
                    corrected_actions_batch,
                    oracle_stop_batch,
                ) = split

                observations_batch = {
                    k: v.to(device=self.device, non_blocking=True)
                    for k, v in observations_batch.items()
                }

                (
                    loss,
                    high_recurrent_hidden_states,
                    low_recurrent_hidden_states,
                    detached_state_low,
                    correct,
                    total,
                ) = self._update_agent_val(
                    observations_batch,
                    prev_actions_batch.to(
                        device=self.device, non_blocking=True
                    ),
                    not_done_masks.to(device=self.device, non_blocking=True),
                    corrected_actions_batch.to(
                        device=self.device, non_blocking=True
                    ),
                    oracle_stop_batch.to(
                        device=self.device, non_blocking=True
                    ),
                    high_recurrent_hidden_states,
                    low_recurrent_hidden_states,
                    detached_state_low,
                )

                correct_labels += correct
                total_labels += total

                # loss = (high-level, low-level action, low-level stop, aux)
                high_loss = loss[0]
                low_loss = loss[1] + loss[2]
                writer.add_scalar(
                    "Val High Level Action Loss", high_loss, val_steps
                )
                writer.add_scalar(
                    "Val Low_level Total Loss", low_loss, val_steps
                )
                val_steps += 1

                # The original appended these to the opposite lists, so the
                # epoch-level scalars were logged under swapped tags.
                val_high_losses.append(high_loss)
                val_low_losses.append(low_loss)

    # Skip epoch-level logging when nothing was evaluated instead of logging
    # NaN means or dividing by zero.
    if val_high_losses:
        writer.add_scalar(
            "Val High level Loss epoch", np.mean(val_high_losses), epoch
        )
        writer.add_scalar(
            "Val Low level Loss epoch", np.mean(val_low_losses), epoch
        )
    if total_labels:
        final_accuracy = 100 * correct_labels / total_labels
        writer.add_scalar("Validation Accuracy", final_accuracy, epoch)

    return val_steps
def _update_agent(
    self,
    observations,
    prev_actions,
    not_done_masks,
    corrected_actions,
    oracle_stop,
    high_recurrent_hidden_states,
    low_recurrent_hidden_states,
    detached_state_low,
):
    """Take one optimisation step on the high-level and low-level policies.

    The high-level head is trained with cross-entropy (targets equal to -1
    are ignored) against ``observations['vln_oracle_action_sensor'] - 1``.
    The low-level head is trained with MSE against ``corrected_actions``
    plus BCE-with-logits against the entries of ``oracle_stop`` that are
    not -1. Each head has its own optimizer, and the low-level work happens
    on ``self.device2``.

    Note that ``observations['vln_oracle_action_sensor']`` is overwritten in
    the caller's dict with an int64 version.

    Returns:
        ``(loss, high_recurrent_hidden_states, low_recurrent_hidden_states,
        detached_state_low)`` where ``loss`` is the tuple of Python numbers
        ``(high_level, low_level_action, low_level_stop, 0)``.
    """
    sensor = 'vln_oracle_action_sensor'

    self.optimizer_high_level.zero_grad()
    self.optimizer_low_level.zero_grad()

    ce_loss = nn.CrossEntropyLoss(ignore_index=-1, reduction="mean")
    mse_loss = nn.MSELoss()
    bce_loss = nn.BCEWithLogitsLoss()

    AuxLosses.clear()

    high_recurrent_hidden_states = repackage_hidden(
        high_recurrent_hidden_states
    )
    low_recurrent_hidden_states = repackage_hidden(low_recurrent_hidden_states)

    if self.config.DAGGER.INTER_MODULE_ATTN:
        # Entries where the oracle action is 0 are treated as padding.
        pad_high = observations[sensor] == 0
        observations[sensor] = observations[sensor].to(dtype=torch.int64)

        # Shift labels down by one; padded slots become class 4 for the
        # low-level conditioning input.
        oracle_ids = observations[sensor]
        pad_ids = oracle_ids == 0
        oracle_ids = (oracle_ids - 1).masked_fill_(pad_ids, 4)

        model_in = (
            observations,
            high_recurrent_hidden_states,
            low_recurrent_hidden_states,
            prev_actions,
            not_done_masks,
            oracle_ids.view(-1),
            detached_state_low,
        )
        (
            high_logits,
            low_pred,
            stop_logits,
            high_recurrent_hidden_states,
            low_recurrent_hidden_states,
            detached_state_low,
        ) = self.actor_critic(model_in)
        del model_in

        # ---- high-level update ----
        high_logits = high_logits.masked_fill_(pad_high, 0)
        high_targets = (observations[sensor] - 1).squeeze(1)
        high_level_loss = ce_loss(high_logits, high_targets)
        high_level_loss.backward()
        high_level_loss_data = high_level_loss.detach()
        del high_logits
        self.optimizer_high_level.step()

        # ---- low-level update (on device2) ----
        oracle_stop = oracle_stop.to(self.device2)
        corrected_actions = corrected_actions.to(self.device2)
        zero_targets = corrected_actions == 0
        low_pred = low_pred.masked_fill_(zero_targets, 0)
        low_pred = low_pred.to(dtype=torch.float)
        corrected_actions = corrected_actions.to(dtype=torch.float)
        low_level_action_loss = mse_loss(low_pred, corrected_actions)

        keep_stop = oracle_stop != -1
        oracle_stop = torch.masked_select(oracle_stop, keep_stop)
        stop_logits = torch.masked_select(stop_logits, keep_stop)
        low_level_stop_loss = bce_loss(stop_logits, oracle_stop)

        low_level_total = low_level_action_loss + low_level_stop_loss
        low_level_total.backward()
        self.optimizer_low_level.step()
    else:
        # ---- high-level update ----
        model_in = (
            observations,
            high_recurrent_hidden_states,
            prev_actions,
            not_done_masks,
        )
        high_logits, high_recurrent_hidden_states = self.high_level(model_in)
        del model_in

        pad_high = observations[sensor] == 0
        high_logits = high_logits.masked_fill_(pad_high, 0)
        observations[sensor] = (
            observations[sensor].squeeze(1).to(dtype=torch.int64)
        )
        high_level_loss = ce_loss(high_logits, observations[sensor] - 1)
        high_level_loss.backward()
        self.optimizer_high_level.step()
        high_level_loss_data = high_level_loss.detach()
        del high_logits

        # ---- low-level update (on device2) ----
        self.low_level.to(self.device2)
        observations = {
            name: tensor.to(device=self.device2, non_blocking=True)
            for name, tensor in observations.items()
        }

        oracle_ids = observations[sensor]
        pad_ids = oracle_ids == 0
        oracle_ids = (oracle_ids - 1).masked_fill_(pad_ids, 4)
        # The low-level policy receives the oracle action as a separate
        # input, not through the observation dict.
        del observations[sensor]

        model_in = (
            observations,
            low_recurrent_hidden_states,
            prev_actions.to(device=self.device2, non_blocking=True),
            not_done_masks.to(device=self.device2, non_blocking=True),
            oracle_ids.view(-1),
        )
        del observations, prev_actions, not_done_masks

        oracle_stop = oracle_stop.to(self.device2)
        low_pred, stop_logits, low_recurrent_hidden_states = self.low_level(
            model_in
        )

        corrected_actions = corrected_actions.to(self.device2)
        zero_targets = corrected_actions == 0
        low_pred = low_pred.masked_fill_(zero_targets, 0)
        low_pred = low_pred.to(dtype=torch.float)
        corrected_actions = corrected_actions.to(dtype=torch.float)
        low_level_action_loss = mse_loss(low_pred, corrected_actions)

        keep_stop = oracle_stop != -1
        oracle_stop = torch.masked_select(oracle_stop, keep_stop)
        stop_logits = torch.masked_select(stop_logits, keep_stop)
        low_level_stop_loss = bce_loss(stop_logits, oracle_stop)

        low_level_total = low_level_action_loss + low_level_stop_loss
        low_level_total.backward()
        self.optimizer_low_level.step()

    # No auxiliary loss is computed here; the slot is kept so the tuple
    # layout stays (high, low action, low stop, aux).
    aux_loss_data = 0
    loss = (
        high_level_loss_data.item(),
        low_level_action_loss.detach().item(),
        low_level_stop_loss.detach().item(),
        aux_loss_data,
    )
    return (
        loss,
        high_recurrent_hidden_states,
        low_recurrent_hidden_states,
        detached_state_low,
    )
def _update_agent_val(
    self,
    observations,
    prev_actions,
    not_done_masks,
    corrected_actions,
    oracle_stop,
    high_recurrent_hidden_states,
    low_recurrent_hidden_states,
    detached_state_low,
):
    """Evaluate both policies on one chunk without touching the optimizers.

    Computes the same three losses as the training step (cross-entropy for
    the high-level head, MSE and BCE-with-logits for the low-level head)
    and counts how many non-padded high-level predictions match the oracle.

    Note that ``observations['vln_oracle_action_sensor']`` is overwritten in
    the caller's dict with an int64 version.

    Returns:
        ``(loss, high_recurrent_hidden_states, low_recurrent_hidden_states,
        detached_state_low, accuracy, total)`` where ``loss`` is the tuple
        ``(high_level, low_level_action, low_level_stop, 0)`` of Python
        numbers, ``accuracy`` is the number of correct high-level
        predictions and ``total`` the number of non-padded targets.
    """
    sensor = 'vln_oracle_action_sensor'

    high_level_criterion = nn.CrossEntropyLoss(
        ignore_index=-1, reduction="mean"
    )
    low_level_criterion = nn.MSELoss()
    low_level_stop_criterion = nn.BCEWithLogitsLoss()

    AuxLosses.clear()

    high_recurrent_hidden_states = repackage_hidden(
        high_recurrent_hidden_states
    )
    low_recurrent_hidden_states = repackage_hidden(low_recurrent_hidden_states)

    def count_correct(logits, targets, padding_mask):
        """Return (#correct, #evaluated) over the non-padded targets."""
        # Flatten the mask so it lines up element-wise with the 1-D
        # predictions; leaving it in its pre-squeeze shape makes
        # masked_select broadcast the two tensors and skews the counts.
        valid = ~padding_mask.reshape(-1)
        predicted = torch.masked_select(torch.argmax(logits, dim=1), valid)
        expected = torch.masked_select(targets.reshape(-1), valid)
        return (predicted == expected).sum().item(), predicted.size(0)

    def low_level_losses(action_out, stop_out):
        """Return (MSE action loss, BCE stop loss), computed on device2."""
        stop_targets = oracle_stop.to(self.device2)
        action_targets = corrected_actions.to(self.device2)
        # Zero the prediction wherever the target is exactly 0 so those
        # entries contribute no error.
        action_out = action_out.masked_fill_(action_targets == 0, 0)
        action_loss = low_level_criterion(
            action_out.to(dtype=torch.float),
            action_targets.to(dtype=torch.float),
        )
        # Stop labels equal to -1 are excluded from the stop loss.
        keep = stop_targets != -1
        stop_loss = low_level_stop_criterion(
            torch.masked_select(stop_out, keep),
            torch.masked_select(stop_targets, keep),
        )
        return action_loss, stop_loss

    if self.config.DAGGER.INTER_MODULE_ATTN:
        # Entries where the oracle action is 0 are treated as padding.
        high_level_action_mask = observations[sensor] == 0
        observations[sensor] = observations[sensor].to(dtype=torch.int64)

        # Shift labels down by one; padded slots become class 4 for the
        # low-level conditioning input.
        discrete_actions = observations[sensor]
        discrete_action_mask = discrete_actions == 0
        discrete_actions = (discrete_actions - 1).masked_fill_(
            discrete_action_mask, 4
        )

        batch = (
            observations,
            high_recurrent_hidden_states,
            low_recurrent_hidden_states,
            prev_actions,
            not_done_masks,
            discrete_actions.view(-1),
            detached_state_low,
        )
        (
            high_output,
            low_output,
            low_stop_output,
            high_recurrent_hidden_states,
            low_recurrent_hidden_states,
            detached_state_low,
        ) = self.actor_critic(batch)
        del batch

        high_output = high_output.masked_fill_(high_level_action_mask, 0)
        high_targets = (observations[sensor] - 1).squeeze(1)
        high_level_loss = high_level_criterion(high_output, high_targets)

        # The original never set accuracy/total on this path, so the return
        # statement raised NameError.
        accuracy, total = count_correct(
            high_output, high_targets, high_level_action_mask
        )
        del high_output

        low_level_action_loss, low_level_stop_loss = low_level_losses(
            low_output, low_stop_output
        )
    else:
        batch = (
            observations,
            high_recurrent_hidden_states,
            prev_actions,
            not_done_masks,
        )
        output, high_recurrent_hidden_states = self.high_level(batch)
        del batch

        high_level_action_mask = observations[sensor] == 0
        output = output.masked_fill_(high_level_action_mask, 0)
        observations[sensor] = (
            observations[sensor].squeeze(1).to(dtype=torch.int64)
        )
        high_targets = observations[sensor] - 1
        high_level_loss = high_level_criterion(output, high_targets)

        accuracy, total = count_correct(
            output, high_targets, high_level_action_mask
        )
        del output

        self.low_level.to(self.device2)
        observations = {
            k: v.to(device=self.device2, non_blocking=True)
            for k, v in observations.items()
        }
        # NOTE(review): training feeds the low-level policy labels with the
        # padded slots replaced by 4 and removes the sensor key from
        # `observations`; validation passes the raw `sensor - 1` (which
        # contains -1) and keeps the key. Left unchanged here -- confirm
        # which input the low-level model expects.
        batch = (
            observations,
            low_recurrent_hidden_states,
            prev_actions.to(device=self.device2, non_blocking=True),
            not_done_masks.to(device=self.device2, non_blocking=True),
            observations[sensor] - 1,
        )
        del observations, prev_actions, not_done_masks

        output, stop_out, low_recurrent_hidden_states = self.low_level(batch)
        low_level_action_loss, low_level_stop_loss = low_level_losses(
            output, stop_out
        )

    # No auxiliary loss is computed here; the slot is kept so the tuple
    # layout stays (high, low action, low stop, aux).
    aux_loss_data = 0
    loss = (
        high_level_loss.item(),
        low_level_action_loss.item(),
        low_level_stop_loss.item(),
        aux_loss_data,
    )
    return (
        loss,
        high_recurrent_hidden_states,
        low_recurrent_hidden_states,
        detached_state_low,
        accuracy,
        total,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment