Skip to content

Instantly share code, notes, and snippets.

@adriantre
Created May 7, 2020 09:42
Show Gist options
  • Save adriantre/829a237afec27290eeb58b8142e5ea68 to your computer and use it in GitHub Desktop.
Returning val_loss from validation_epoch_end results in TypeError: can't pickle _thread.lock objects
# Version: pytorch-lightning==0.7.5
# Run: python -m lightning_run --batch_size 5 --num_workers 0 --max_epochs 1
def validation_step(self, batch, batch_idx):
    """Run one validation batch and return its loss and accuracy.

    Args:
        batch: An ``(inputs, targets)`` pair from the validation dataloader.
        batch_idx: Index of the batch within the epoch (unused here).

    Returns:
        dict with ``"val_loss"`` (the loss returned by ``self.forward``) and
        ``"val_acc"`` (number of positions where ``preds == targets`` divided
        by ``len(targets)``, as a floating-point tensor).
    """
    inputs, targets = batch
    # NOTE(review): assumes self.forward(inputs, targets) returns
    # (preds, loss) with preds shaped like targets -- confirm in the model.
    preds, val_loss = self.forward(inputs, targets)
    # Cast the match count to float before dividing so the result is a
    # float tensor regardless of torch's type-promotion rules; older torch
    # releases kept the integer dtype and truncated the quotient to 0.
    correct = torch.sum(preds == targets).float()
    val_acc = correct / len(targets)
    return {
        "val_loss": val_loss,
        "val_acc": val_acc,
    }
def validation_epoch_end(self, outputs):
    """Aggregate per-batch validation outputs into epoch-level metrics.

    Args:
        outputs: List of the dicts returned by ``validation_step``, each
            holding scalar tensors under ``"val_loss"`` and ``"val_acc"``.

    Returns:
        dict with the mean ``"val_loss"`` at top level, plus ``"progress_bar"``
        and ``"log"`` entries that both hold the mean loss and accuracy.
        Returning the top-level ``"val_loss"`` is what triggers the
        checkpoint save shown in the stacktrace.

    Raises:
        RuntimeError: if ``outputs`` is empty (``torch.stack`` needs at least
            one tensor).
    """
    # Unweighted mean over batches, same as the manual sum-then-divide.
    val_loss_mean = torch.stack([out["val_loss"] for out in outputs]).mean()
    val_acc_mean = torch.stack([out["val_acc"] for out in outputs]).mean()
    metrics = {
        "val_loss": val_loss_mean,
        "val_acc": val_acc_mean,
    }
    # NOTE(review): the reported "can't pickle _thread.lock objects" is
    # raised from torch.save while dumping the checkpoint, not here. These
    # metric tensors pickle fine, so the unpicklable object presumably sits
    # elsewhere in the checkpoint (e.g. model hparams) -- confirm.
    return {
        "val_loss": val_loss_mean,
        "progress_bar": metrics,
        "log": metrics,
    }
# Providing external dataloaders: pass the train/val dataloaders to
# trainer.fit() explicitly. This is the call (lightning_run.py line 116)
# from which the TypeError in the stacktrace propagates, via the checkpoint
# callback that runs after validation.
trainer.fit(
model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloader
)
# Stacktrace
Traceback (most recent call last):
File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/lightning_run.py", line 116, in <module>
model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 793, in fit
self.run_pretrain_routine(model)
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 913, in run_pretrain_routine
self.train()
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 347, in train
self.run_training_epoch()
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 452, in run_training_epoch
self.call_checkpoint_callback()
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 789, in call_checkpoint_callback
self.checkpoint_callback.on_validation_end(self, self.get_model())
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py", line 10, in wrapped_fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 231, in on_validation_end
self._do_check_save(filepath, current, epoch)
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 265, in _do_check_save
self._save_model(filepath)
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 142, in _save_model
self.save_function(filepath)
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_io.py", line 253, in save_checkpoint
self._atomic_save(checkpoint, filepath)
File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_io.py", line 244, in _atomic_save
torch.save(checkpoint, tmp_path)
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 328, in save
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 401, in _legacy_save
pickler.dump(obj)
TypeError: can't pickle _thread.lock objects
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment