Nuscenes
import os
import random
import pickle

import numpy as np
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image

from nuscenes import NuScenes  # kept so the pickled database class resolves
from nuscenes.eval.prediction.splits import get_prediction_challenge_split
from nuscenes.prediction import PredictHelper

from lib.dataloaders.static_layers import StaticLayerRasterizer
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
class NUSCENESDataLayer(data.Dataset):
    def __init__(self, args, split):
        self.args = args
        self.split = split
        self.batch_size = args.batch_size
        self.DATAROOT = './data/nuscenes'
        self.map_root = os.path.join(self.DATAROOT, 'token_lane_img', self.split)
        self.mask_root = os.path.join(self.DATAROOT, 'token_mask', self.split)
        # Loading NuScenes('v1.0-trainval') from scratch is slow, so a pickled
        # copy of the database is loaded instead. To (re)create it:
        #   nuscenes = NuScenes('v1.0-trainval', dataroot=self.DATAROOT)
        #   with open('lib/dataloaders/v1.0-trainval.pickle', 'wb') as handle:
        #       pickle.dump(nuscenes, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open('lib/dataloaders/v1.0-trainval.pickle', 'rb') as handle:
            self.nuscenes = pickle.load(handle)
        self.tokens = get_prediction_challenge_split(self.split, dataroot=self.DATAROOT)
        self.helper = PredictHelper(self.nuscenes)
        self.static_layer_rasterizer = StaticLayerRasterizer(
            self.helper,
            layer_names=['lane', 'road_segment', 'drivable_area', 'road_divider', 'lane_divider'])
        self.dataset = []
        self.vah = True  # append velocity / acceleration / heading-change-rate features
        self.all_traj_list = []
        for token in self.tokens:
            instance_token, sample_token = token.split("_")
            # Frames are sampled at 2 Hz, so dec_steps frames span dec_steps // 2 seconds.
            future_traj_local = self.helper.get_future_for_agent(
                instance_token, sample_token, seconds=self.args.dec_steps // 2, in_agent_frame=True)
            future_traj_global = self.helper.get_future_for_agent(
                instance_token, sample_token, seconds=self.args.dec_steps // 2, in_agent_frame=False)
            if future_traj_local.shape[0] < 12:
                # Skip agents without a full 6 s (12-frame) future.
                continue
            elif future_traj_local.shape[0] == 12:
                current_traj_dict = self.helper.get_sample_annotation(instance_token, sample_token)
                current_traj_global = current_traj_dict['translation'][:2]
                assert current_traj_dict['category_name'].startswith('vehicle')
                starting_translation = current_traj_dict['translation']
                starting_rotation = current_traj_dict['rotation']
                past_traj_local = self.helper.get_past_for_agent(
                    instance_token, sample_token, seconds=self.args.enc_steps // 2,
                    in_agent_frame=True)[::-1]
                past_traj_dict = self.helper.get_past_for_agent(
                    instance_token, sample_token, seconds=self.args.enc_steps // 2,
                    in_agent_frame=True, just_xy=False)[::-1]
                input_tokens = [(p['instance_token'], p['sample_token']) for p in past_traj_dict]
                # Agent-frame coordinates: the current position is the origin.
                input_traj = np.vstack((past_traj_local, [0, 0]))
                if self.vah:
                    input_tokens = np.vstack((input_tokens, [instance_token, sample_token]))
                    vah = self.get_velo_acc_head(input_tokens)
                    input_traj = np.hstack((input_traj, vah))
                first_history_index = self.args.enc_steps - input_traj.shape[0]
                all_traj = np.vstack((input_traj[:, :2], future_traj_local))
                if input_traj.shape[0] < self.args.enc_steps:
                    # Zero-pad short histories at the front.
                    zero_padding = np.zeros((first_history_index, self.args.input_dim))
                    all_traj = np.vstack((zero_padding[:, :2], all_traj))
                    input_traj = np.vstack((zero_padding, input_traj))
                target_traj = self.get_target(all_traj, first_history_index,
                                              self.args.enc_steps, self.args.enc_steps,
                                              self.args.dec_steps)
                map_mask = os.path.join(self.map_root, token + '.png')
                self.dataset.append([first_history_index, input_traj, target_traj,
                                     future_traj_global, starting_translation,
                                     starting_rotation, token, map_mask])
            else:
                raise ValueError('Future trajectory length > 12')
        # Group sample indices by first_history_index so every batch has a
        # uniform history length, e.g. self.len_dict[2] -> indices of samples
        # whose histories are padded by two frames.
        self.len_dict = {}
        for index in range(len(self.dataset)):
            first_history_index = self.dataset[index][0]
            if first_history_index not in self.len_dict:
                self.len_dict[first_history_index] = []
            self.len_dict[first_history_index].append(index)
        self.shuffle_dataset()
    def shuffle_dataset(self):
        self._init_inputs()

    def _init_inputs(self):
        # Batches are pre-built here: shuffle within each history-length group,
        # then cut each group into batch_size-sized chunks.
        self.inputs = []
        for length in self.len_dict:
            indices = self.len_dict[length]
            random.shuffle(indices)
            self.inputs.extend(list(chunks(indices, self.batch_size)))

    def __len__(self):
        return len(self.inputs)
    def get_velo_acc_head(self, input_tokens):
        vah = np.zeros((input_tokens.shape[0], 3))
        for i in reversed(range(input_tokens.shape[0])):
            inst, sample = input_tokens[i]
            # Meters / second.
            vah[i, 0] = self.helper.get_velocity_for_agent(inst, sample, max_time_diff=1.5)
            # Meters / second^2.
            vah[i, 1] = self.helper.get_acceleration_for_agent(inst, sample, max_time_diff=1.5)
            # Radians / second.
            vah[i, 2] = self.helper.get_heading_change_rate_for_agent(inst, sample, max_time_diff=1.5)
        return np.nan_to_num(vah)
    def process_vah(self, vah):
        # Alternative NaN handling (currently unused): back-fill each feature
        # from the following timestep, zeroing NaNs at the last step.
        for i in reversed(range(vah.shape[0])):
            for j in range(3):
                if np.isnan(vah[i, j]):
                    vah[i, j] = 0.0 if i == vah.shape[0] - 1 else vah[i + 1, j]
        return vah
    def __getitem__(self, index):
        # Each item is a whole pre-built batch of samples sharing the same
        # first_history_index (hence the outer dataloader uses batch_size=1).
        indices = self.inputs[index]
        ret = {
            'input_x': [],
            'target_y': [],
            'first_history_index': [],
            'y_raw': [],
            'starting_translation': [],
            'starting_rotation': [],
            'token': [],
            'map_mask': []
        }
        for idx in indices:
            this_ret = self.getitem_one(idx)
            ret['input_x'].append(torch.as_tensor(this_ret['input_x']).float())
            ret['target_y'].append(torch.as_tensor(this_ret['target_y']).float())
            ret['y_raw'].append(torch.as_tensor(this_ret['y_raw']).float())
            ret['starting_translation'].append(torch.as_tensor(this_ret['starting_translation']).float())
            ret['starting_rotation'].append(torch.as_tensor(this_ret['starting_rotation']).float())
            ret['first_history_index'].append(torch.as_tensor(this_ret['first_history_index']).long())
            ret['token'].append(this_ret['token'])
            lane_mask = self.read_img(this_ret['map_mask']).copy()
            lane_mask = transforms.ToTensor()(lane_mask)
            ret['map_mask'].append(lane_mask)
        for key in ['input_x', 'target_y', 'y_raw', 'starting_translation',
                    'starting_rotation', 'first_history_index', 'map_mask']:
            ret[key] = torch.stack(ret[key])
        return ret
    def read_img(self, map_mask_root):
        img = Image.open(map_mask_root)  # open the rasterized map with Pillow
        img = img.resize((100, 100))     # downsample to 100x100
        return img

    def read_multichannel_img(self, map_mask_root):
        # Load a (C, H, W) binary mask saved with np.save, flip it vertically,
        # move channels last, and resize each channel to 100x100.
        mask = np.transpose(np.load(map_mask_root)[:, ::-1, :] * 255, (1, 2, 0)).astype(np.uint8)
        mask_tensor = torch.zeros((mask.shape[2], 100, 100))
        for m in range(mask.shape[2]):
            img = Image.fromarray(mask[:, :, m])
            img = img.resize((100, 100))
            mask_tensor[m] = (transforms.ToTensor()(img) > 0).float()
        return mask_tensor

    def read_npmask(self, map_mask_root):
        return np.load(map_mask_root)
    def getitem_one(self, index):
        (first_history_index, x_t, y_t, y_raw, starting_translation,
         starting_rotation, token, map_mask) = self.dataset[index]
        return {
            'first_history_index': first_history_index,
            'input_x': x_t,
            'target_y': y_t,
            'y_raw': y_raw,
            'starting_translation': starting_translation,
            'starting_rotation': starting_rotation,
            'token': token,
            'map_mask': map_mask,
        }
    def get_target(self, session, start, end, observe_length, predict_length):
        '''
        Given an input session and the start/end frames of the input clip, build
        the prediction target. The target is the *change* in future values,
        e.g. the target for yaw is \delta\theta_{t0, tn}.
        Params:
            session: time sequence of one car (bbox or ego motion), shape (time, :)
            start: start frame id
            end: end frame id
        Returns:
            target: array with shape (observe_length, predict_length, :)
        '''
        target = np.zeros((observe_length, predict_length, session.shape[-1]))
        for i, target_start in enumerate(range(start, end), start=start):
            # The target at time t is the change of bbox/ego motion over [t+1, ..., t+predict_length].
            target_start = target_start + 1
            try:
                target[i, :, :] = np.asarray(session[target_start:target_start + predict_length, :]
                                             - session[target_start - 1, :])
            except Exception:
                print("segment start:", start)
                print("sample start:", target_start)
                print("segment end:", end)
                print(session.shape)
                raise
        return target
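
For reference, a minimal sketch of how this data layer might be driven. The args fields and their values are assumptions consistent with the checks in __init__ (dec_steps=12 matches the 6 s future at 2 Hz; input_dim=5 is x, y plus the three vah features), not part of the gist:

# Minimal usage sketch (hypothetical args; assumes the pickled NuScenes DB
# and the rasterized map PNGs described above already exist on disk).
from argparse import Namespace

args = Namespace(batch_size=32, enc_steps=9, dec_steps=12, input_dim=5)
dataset = NUSCENESDataLayer(args, split='train')
batch = dataset[0]  # one pre-built batch of samples sharing a history length
print(batch['input_x'].shape)   # expected: (B, enc_steps, input_dim)
print(batch['target_y'].shape)  # expected: (B, enc_steps, dec_steps, 2)
print(batch['map_mask'].shape)  # expected: (B, 3, 100, 100)
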
import os
import pickle

from PIL import Image
from nuscenes.eval.prediction.splits import get_prediction_challenge_split
from nuscenes.prediction import PredictHelper

import lib.utils as utl
from lib.dataloaders.agents import AgentBoxesWithFadedHistory
from lib.dataloaders.static_layers import StaticLayerRasterizer
from lib.dataloaders.interface import InputRepresentation
from nuscenes.prediction.input_representation.combinators import Rasterizer
from lib.dataloaders.helper import *
def main():
    # Pre-render one map + agent-history raster per prediction token and cache
    # it as a PNG, so the data layer only has to load images at train time.
    splits = ['train', 'val', 'train_val']
    DATAROOT = './data/nuscenes'
    with open('lib/dataloaders/v1.0-trainval.pickle', 'rb') as handle:
        nuscenes = pickle.load(handle)
    print('nuscenes is loaded')
    for split in splits:
        map_root = os.path.join(DATAROOT, 'lane_agent_img', split)
        os.makedirs(map_root, exist_ok=True)
        tokens = get_prediction_challenge_split(split, dataroot=DATAROOT)
        helper = PredictHelper(nuscenes)
        static_layer_rasterizer = StaticLayerRasterizer(
            helper,
            layer_names=['lane', 'road_segment', 'road_block', 'ped_crossing', 'stop_line',
                         'carpark_area', 'walkway', 'drivable_area', 'road_divider', 'lane_divider'])
        agent_rasterizer = AgentBoxesWithFadedHistory(helper, seconds_of_history=1.5)
        mtp_input_representation = InputRepresentation(static_layer_rasterizer, agent_rasterizer, Rasterizer())
        for i, token in enumerate(tokens):
            instance_token, sample_token = token.split("_")
            # Flip vertically so the raster matches the orientation expected downstream.
            map_mask = mtp_input_representation.make_input_representation(instance_token, sample_token)[::-1, :]
            im = Image.fromarray(map_mask)
            im.save(os.path.join(map_root, token + ".png"))

if __name__ == '__main__':
    main()
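
Note the directory names differ between the two files: the data layer above reads from token_lane_img while this script writes lane_agent_img, so presumably one of the roots needs adjusting, or a second (unshown) script produces the token_lane_img rasters. A quick sanity check on the cached PNGs, with paths assumed from the script above:

# Sanity-check sketch: open one cached raster and confirm it is an RGB image.
import os
from PIL import Image

map_root = './data/nuscenes/lane_agent_img/train'
fname = sorted(os.listdir(map_root))[0]  # any "<instance>_<sample>.png" token
img = Image.open(os.path.join(map_root, fname))
print(fname, img.mode, img.size)
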
__all__ = ['SGDNet_ED_CVAE_DUAL']
import torch
import torch.nn as nn
import torch.nn.functional as F

from .feature_extractor import build_feature_extractor
from .bitrap_np import BiTraPNP


class SGDNet_ED_CVAE_DUAL(nn.Module):
    def __init__(self, args):
        super(SGDNet_ED_CVAE_DUAL, self).__init__()
        self.cvae = BiTraPNP(args)
        self.hidden_size = args.hidden_size  # GRU hidden size
        self.enc_steps = args.enc_steps      # observation steps
        self.dec_steps = args.dec_steps      # prediction steps
        self.dataset = args.dataset
        self.dropout = args.dropout
        self.feature_extractor = build_feature_extractor(args)
        self.pred_dim = args.pred_dim
        self.K = args.K  # number of CVAE samples
        if self.dataset in ['NUSCENES']:
            self.pred_dim = 2  # (x, y) offsets
        self.regressor = nn.Sequential(nn.Linear(self.hidden_size, self.pred_dim))
        self.map_enc_cell = nn.GRUCell(self.hidden_size + self.hidden_size // 4, self.hidden_size)
        # Attention scorers over the stepwise goal hidden states.
        self.enc_goal_attn = nn.Sequential(nn.Linear(self.hidden_size // 4, 1), nn.ReLU(inplace=True))
        self.dec_goal_attn = nn.Sequential(nn.Linear(self.hidden_size // 4, 1), nn.ReLU(inplace=True))
        # Projections between encoder, goal, and decoder spaces.
        self.enc_to_goal_hidden = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size // 4), nn.ReLU(inplace=True))
        self.goal_hidden_to_traj = nn.Sequential(nn.Linear(self.hidden_size // 4, self.hidden_size), nn.ReLU(inplace=True))
        self.cvae_to_dec_hidden = nn.Sequential(nn.Linear(self.hidden_size + args.LATENT_DIM, self.hidden_size), nn.ReLU(inplace=True))
        self.enc_to_dec_hidden = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(inplace=True))
        self.flow_dec = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(inplace=True))
        self.goal_hidden_to_input = nn.Sequential(nn.Linear(self.hidden_size // 4, self.hidden_size // 4), nn.ReLU(inplace=True))
        self.dec_hidden_to_input = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(inplace=True))
        self.goal_to_enc = nn.Sequential(nn.Linear(self.hidden_size // 4, self.hidden_size // 4), nn.ReLU(inplace=True))
        self.goal_to_dec = nn.Sequential(nn.Linear(self.hidden_size // 4, self.hidden_size // 4), nn.ReLU(inplace=True))
        self.enc_drop = nn.Dropout(self.dropout)
        self.goal_drop = nn.Dropout(self.dropout)
        self.dec_drop = nn.Dropout(self.dropout)
        self.traj_enc_cell = nn.GRUCell(self.hidden_size + self.hidden_size // 4, self.hidden_size)
        self.goal_cell = nn.GRUCell(self.hidden_size // 4, self.hidden_size // 4)
        self.dec_cell = nn.GRUCell(self.hidden_size, self.hidden_size)
    def SGE(self, goal_hidden):
        # Stepwise goal estimator: unroll a small GRU over the prediction
        # horizon to produce one goal hidden state per future step.
        goal_input = goal_hidden.new_zeros((goal_hidden.size(0), self.hidden_size // 4))
        goal_traj = goal_hidden.new_zeros(goal_hidden.size(0), self.dec_steps, self.pred_dim)
        goal_list = []
        for dec_step in range(self.dec_steps):
            goal_hidden = self.goal_cell(self.goal_drop(goal_input), goal_hidden)
            # The next step's input is generated from the current hidden state.
            goal_input = self.goal_hidden_to_input(goal_hidden)
            goal_list.append(goal_hidden)
            # Regress a goal trajectory for the goal loss.
            goal_traj_hidden = self.goal_hidden_to_traj(goal_hidden)
            goal_traj[:, dec_step, :] = self.regressor(goal_traj_hidden)
        # Project the goals for the decoder and the encoder.
        goal_for_dec = [self.goal_to_dec(goal) for goal in goal_list]
        goal_for_enc = torch.stack([self.goal_to_enc(goal) for goal in goal_list], dim=1)
        # Soft attention over the dec_steps goals yields one encoder goal vector.
        enc_attn = self.enc_goal_attn(torch.tanh(goal_for_enc)).squeeze(-1)
        enc_attn = F.softmax(enc_attn, dim=1).unsqueeze(1)
        goal_for_enc = torch.bmm(enc_attn, goal_for_enc).squeeze(1)
        return goal_for_dec, goal_for_enc, goal_traj
    def cvae_decoder(self, dec_hidden, goal_for_dec):
        batch_size = dec_hidden.size(0)
        K = dec_hidden.shape[1]
        dec_hidden = dec_hidden.view(-1, dec_hidden.shape[-1])
        dec_traj = dec_hidden.new_zeros(batch_size, self.dec_steps, K, self.pred_dim)
        for dec_step in range(self.dec_steps):
            # Attend over the remaining (future) goals at each time step.
            goal_dec_input = dec_hidden.new_zeros(batch_size, self.dec_steps, self.hidden_size // 4)
            goal_dec_input_temp = torch.stack(goal_for_dec[dec_step:], dim=1)
            goal_dec_input[:, dec_step:, :] = goal_dec_input_temp
            dec_attn = self.dec_goal_attn(torch.tanh(goal_dec_input)).squeeze(-1)
            dec_attn = F.softmax(dec_attn, dim=1).unsqueeze(1)
            goal_dec_input = torch.bmm(dec_attn, goal_dec_input).squeeze(1)
            goal_dec_input = goal_dec_input.unsqueeze(1).repeat(1, K, 1).view(-1, goal_dec_input.shape[-1])
            dec_dec_input = self.dec_hidden_to_input(dec_hidden)
            # NOTE: the attended goal is computed but not concatenated here;
            # the decoder input is the hidden-state projection alone.
            dec_input = dec_dec_input
            dec_hidden = self.dec_cell(dec_input, dec_hidden)
            # Regress the decoder trajectory for the trajectory loss.
            batch_traj = self.regressor(dec_hidden)
            batch_traj = batch_traj.view(-1, K, batch_traj.shape[-1])
            dec_traj[:, dec_step, :, :] = batch_traj
        return dec_traj
    def encoder(self, raw_inputs, raw_targets, traj_input, flow_input=None, start_index=0):
        # Output tensors.
        all_goal_traj = traj_input.new_zeros(traj_input.size(0), self.enc_steps, self.dec_steps, self.pred_dim)
        all_cvae_dec_traj = traj_input.new_zeros(traj_input.size(0), self.enc_steps, self.dec_steps, self.K, self.pred_dim)
        # Initialize the encoder goal and hidden state with zeros.
        goal_for_enc = traj_input.new_zeros((traj_input.size(0), self.hidden_size // 4))
        traj_enc_hidden = traj_input.new_zeros((traj_input.size(0), self.hidden_size))
        total_probabilities = traj_input.new_zeros((traj_input.size(0), self.enc_steps, self.K))
        total_KLD = 0
        for enc_step in range(start_index, self.enc_steps):
            traj_enc_hidden = self.traj_enc_cell(
                self.enc_drop(torch.cat((traj_input[:, enc_step, :], goal_for_enc), 1)), traj_enc_hidden)
            if self.dataset in ['NUSCENES']:
                enc_hidden = traj_enc_hidden
            goal_hidden = self.enc_to_goal_hidden(enc_hidden)
            goal_for_dec, goal_for_enc, goal_traj = self.SGE(goal_hidden)
            all_goal_traj[:, enc_step, :, :] = goal_traj
            dec_hidden = self.enc_to_dec_hidden(enc_hidden)
            if self.training:
                # Ground-truth futures condition the CVAE posterior at train time.
                cvae_hidden, KLD, probability = self.cvae(
                    dec_hidden, raw_inputs[:, enc_step, :], self.K, raw_targets[:, enc_step, :, :])
            else:
                cvae_hidden, KLD, probability = self.cvae(dec_hidden, raw_inputs[:, enc_step, :], self.K)
            total_probabilities[:, enc_step, :] = probability
            total_KLD += KLD
            cvae_dec_hidden = self.cvae_to_dec_hidden(cvae_hidden)
            # Fuse the embedded map features into the decoder hidden state.
            cvae_dec_hidden = (cvae_dec_hidden + flow_input.unsqueeze(1)) / 2
            all_cvae_dec_traj[:, enc_step, :, :, :] = self.cvae_decoder(cvae_dec_hidden, goal_for_dec)
        return all_goal_traj, all_cvae_dec_traj, total_KLD, total_probabilities
    def forward(self, inputs, map_mask, targets=None, start_index=0, training=True):
        self.training = training
        if torch.is_tensor(start_index):
            start_index = start_index[0].item()
        if self.dataset in ['NUSCENES']:
            # Embed the observed trajectory and the rasterized map jointly.
            traj_input_temp, embedded_map = self.feature_extractor([inputs[:, start_index:, :], map_mask])
            traj_input = traj_input_temp.new_zeros((inputs.size(0), inputs.size(1), traj_input_temp.size(-1)))
            traj_input[:, start_index:, :] = traj_input_temp
            map_input = embedded_map
        all_goal_traj, all_cvae_dec_traj, KLD, total_probabilities = self.encoder(
            inputs, targets, traj_input, map_input, start_index)
        return all_goal_traj, all_cvae_dec_traj, KLD, total_probabilities
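
A shape-level sketch of driving the model, assuming the project's BiTraPNP and feature-extractor modules are importable; the hyperparameter values are assumptions, not from the gist:

# Hypothetical smoke test (hyperparameters are placeholders).
from argparse import Namespace
import torch

args = Namespace(hidden_size=256, enc_steps=9, dec_steps=12, dataset='NUSCENES',
                 dropout=0.0, pred_dim=2, K=20, LATENT_DIM=32)
model = SGDNet_ED_CVAE_DUAL(args)

inputs = torch.randn(4, args.enc_steps, 5)    # (B, enc_steps, input_dim)
map_mask = torch.randn(4, 3, 100, 100)        # raster from the data layer
targets = torch.randn(4, args.enc_steps, args.dec_steps, 2)
goal_traj, cvae_traj, KLD, probs = model(inputs, map_mask, targets, training=True)
# goal_traj: (B, enc_steps, dec_steps, 2); cvae_traj: (B, enc_steps, dec_steps, K, 2)
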
import os
import os.path as osp

import torch
from torch import nn, optim

import lib.utils as utl
from configs.nuscenes import parse_sgd_args as parse_args
from lib.models import build_model
from lib.losses import rmse_loss
def main(args):
    this_dir = osp.dirname(__file__)
    model_name = args.model
    save_dir = osp.join(this_dir, 'checkpoints', args.dataset, model_name, str(args.K), str(args.seed))
    if not osp.isdir(save_dir):
        os.makedirs(save_dir)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    utl.set_seed(int(args.seed))
    model = build_model(args)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    model = nn.DataParallel(model)
    model = model.to(device)
    if osp.isfile(args.checkpoint):
        # Resume from a checkpoint if one was given.
        checkpoint = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        args.start_epoch += checkpoint['epoch']
        del checkpoint
    criterion = rmse_loss().to(device)
    # batch_size=1: the nuScenes data layer returns pre-built batches.
    train_gen = utl.build_data_loader(args, 'train', batch_size=1)
    print("Number of training samples:", len(train_gen))
    for epoch in range(args.start_epoch, args.epochs + args.start_epoch):
        # train() is not included in this gist; see the sketch below.
        train_goal_loss, train_cvae_loss, train_KLD_loss = train(model, train_gen, criterion, optimizer, device)
        print('Train Epoch: {} \t Goal loss: {:.4f}\t CVAE loss: {:.4f}\t KLD loss: {:.4f}\t Total: {:.4f}'.format(
            epoch, train_goal_loss, train_cvae_loss, train_KLD_loss,
            train_goal_loss + train_cvae_loss + train_KLD_loss))
if __name__ == '__main__':
    main(parse_args())
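
The train() function the script calls is not part of the gist. A minimal sketch of what it plausibly looks like, given the model's four outputs and the rmse_loss criterion; everything here is an assumption about the missing code, not the author's implementation:

# Hypothetical train() sketch -- NOT from the gist. Assumes the batch layout
# produced by NUSCENESDataLayer (dummy leading dim from the outer DataLoader)
# and that criterion accepts (prediction, target) tensors of matching shape.
def train(model, train_gen, criterion, optimizer, device):
    model.train()
    total_goal_loss = total_cvae_loss = total_KLD_loss = 0.0
    for batch in train_gen:
        # Drop the dummy leading dim added by the outer DataLoader (batch_size=1).
        input_x = batch['input_x'].squeeze(0).to(device)
        target_y = batch['target_y'].squeeze(0).to(device)
        map_mask = batch['map_mask'].squeeze(0).to(device)
        first_history_index = batch['first_history_index'].squeeze(0)

        all_goal_traj, all_cvae_dec_traj, KLD, _ = model(
            input_x, map_mask, targets=target_y,
            start_index=first_history_index, training=True)

        goal_loss = criterion(all_goal_traj, target_y)
        # Expand the target over the K samples; a best-of-K variant is also common.
        cvae_loss = criterion(
            all_cvae_dec_traj,
            target_y.unsqueeze(3).repeat(1, 1, 1, all_cvae_dec_traj.shape[3], 1))
        loss = goal_loss + cvae_loss + KLD.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_goal_loss += goal_loss.item()
        total_cvae_loss += cvae_loss.item()
        total_KLD_loss += KLD.mean().item()
    n = len(train_gen)
    return total_goal_loss / n, total_cvae_loss / n, total_KLD_loss / n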