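Before training, save a copy of the weights:

```python
import copy

last = copy.deepcopy(model.state_dict())
```

`copy.deepcopy` matters here. `model.state_dict()` returns tensors that share storage with the live parameters, so a plain reference would be updated in place by every optimizer step, and restoring it would change nothing.

Inside the training loop, after computing the loss:

```python
if torch.isnan(loss).any():
    model.load_state_dict(last)  # roll back to the last weights with a finite loss
    continue  # skip backward()/step() for this batch
else:
    last = copy.deepcopy(model.state_dict())  # snapshot a copy, not a reference
```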
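For reference, here is a self-contained sketch of the same rollback idea inside a complete loop. It assumes a plain supervised loop with `optimizer.zero_grad()`, `backward()` and `step()`. The helper name `train_with_nan_rollback` and the toy `nn.Linear`/SGD setup are hypothetical. One addition beyond the snippet, under the assumption that a stateful optimizer such as Adam could also have been corrupted: the optimizer state is snapshotted and restored alongside the model. Because each snapshot is taken before the update, a rollback also undoes the most recent step.

```python
import copy

import torch
from torch import nn


def train_with_nan_rollback(model, optimizer, loss_fn, batches):
    """Hypothetical helper: one pass over `batches`, rolling back on NaN loss.

    Returns how many batches were skipped because their loss was NaN.
    """
    # Deep copies: state_dict() tensors share storage with the live
    # parameters/optimizer buffers, so plain references would keep changing.
    last_model = copy.deepcopy(model.state_dict())
    last_optim = copy.deepcopy(optimizer.state_dict())  # assumption: restore optimizer too
    skipped = 0
    for inputs, targets in batches:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        if torch.isnan(loss).any():
            # Restore the last state that gave a finite loss and skip the batch.
            model.load_state_dict(last_model)
            optimizer.load_state_dict(last_optim)
            skipped += 1
            continue
        # Finite loss: these weights are "good", so snapshot them before the
        # update that might break them.
        last_model = copy.deepcopy(model.state_dict())
        last_optim = copy.deepcopy(optimizer.state_dict())
        loss.backward()
        optimizer.step()
    return skipped


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 3), torch.randn(8, 1)
    bad_x = x.clone()
    bad_x[0, 0] = float("nan")  # poisoned batch -> NaN loss
    batches = [(x, y), (bad_x, y), (x, y)]
    print(train_with_nan_rollback(model, optimizer, nn.MSELoss(), batches))  # prints 1
```

Checking the loss before `backward()` means NaN gradients from that batch never reach the optimizer; the snapshot/restore then covers the case where an earlier update itself pushed the weights to NaN.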