Skip to content

Instantly share code, notes, and snippets.

@jesuscast
Last active October 16, 2017 05:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jesuscast/42a2aae76b5f02e28bdfd1e76bb36f77 to your computer and use it in GitHub Desktop.
Save jesuscast/42a2aae76b5f02e28bdfd1e76bb36f77 to your computer and use it in GitHub Desktop.
torch_qr_fail.py
import torch
import torch.nn as nn
import multiprocessing as mp
def get_flattened(tensor):
    """Return a freshly sampled 2-D matrix shaped like ``tensor`` flattened.

    This mirrors only the first step of orthogonal initialization, as found
    in https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

    Only the shape of ``tensor`` is used; its values are never read.

    Args:
        tensor: Tensor whose first dimension gives the number of rows and
            whose remaining dimensions are collapsed into the columns.

    Returns:
        A new ``(rows, cols)`` tensor filled in place with samples drawn
        from a normal distribution with mean 0 and standard deviation 1.
    """
    rows = tensor.size(0)
    # Everything past the first dimension is collapsed into one axis.
    cols = tensor[0].numel()
    flattened = torch.Tensor(rows, cols).normal_(0, 1)
    return flattened
def get_mat(wc_dim=200):
    """Return the weight data of a weight-normalized square linear layer.

    Builds a bias-free ``nn.Linear(wc_dim, wc_dim)``, wraps it with
    ``nn.utils.weight_norm`` on its ``weight`` parameter, and returns the
    resulting ``weight.data`` tensor. This is the tensor handed to the
    initialization step in this reproduction.

    Args:
        wc_dim: Input and output feature count of the linear layer.
            Defaults to 200, the size used by the original reproduction.

    Returns:
        The ``(wc_dim, wc_dim)`` weight tensor of the wrapped layer.
    """
    layer = nn.Linear(wc_dim, wc_dim, bias=False)
    mat = nn.utils.weight_norm(layer, name="weight")
    tensor = mat.weight.data
    return tensor
def exp():
    """Run the experiment: QR-decompose a freshly sampled matrix.

    Obtains the weight tensor from ``get_mat``, samples a normal matrix of
    the matching flattened shape via ``get_flattened``, and passes it to
    ``torch.qr``.

    Returns:
        Whatever ``torch.qr`` returns for the sampled matrix.
    """
    tensor = get_mat()
    flattened = get_flattened(tensor)
    # Per the original author's note: this call is expected to fail when
    # the function is executed in a child process.
    return torch.qr(flattened)
if __name__ == "__main__":
    # Local import keeps this entry block self-contained; the signal module
    # names the expected exit code instead of hard-coding -11.
    import signal

    # First call, made directly in the parent process.
    result = exp()
    assert result is not None, "Oops, it should have not failed here"

    # Now run the same function in a child process.
    p = mp.Process(target=exp)
    p.start()
    p.join()

    # multiprocessing reports a child killed by signal N as exit code -N.
    assert p.exitcode != 0, "Child process call to qr did not fail"
    assert p.exitcode == -signal.SIGSEGV, "The error is not a segmentation fault %s" % p.exitcode
    print("Child process call to QR failed as expected")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment