Last active: October 16, 2017 05:08
-
-
Save jesuscast/42a2aae76b5f02e28bdfd1e76bb36f77 to your computer and use it in GitHub Desktop.
torch_qr_fail.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing as mp
import signal

import torch
import torch.nn as nn
def get_flattened(tensor):
    """Return a random N(0, 1) matrix shaped like ``tensor`` flattened to 2-D.

    This is the first step of orthogonal initialization, adapted from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

    Args:
        tensor: Tensor with at least 2 dimensions. Only its shape is read;
            its values are not used.

    Returns:
        A new ``(rows, cols)`` tensor of the default dtype filled with samples
        from N(0, 1), where ``rows = tensor.size(0)`` and ``cols`` is the
        product of the remaining dimensions.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    # Same guard as the upstream orthogonal init: a 0-d/1-d tensor has no
    # meaningful (rows, cols) split.
    if tensor.dim() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")
    rows = tensor.size(0)
    # Product of the trailing dimensions; unlike tensor[0].numel() this does
    # not index into the tensor, so it also works when rows == 0.
    cols = tensor.shape[1:].numel()
    # torch.empty replaces the legacy torch.Tensor(rows, cols) constructor;
    # both allocate an uninitialized tensor of the default dtype.
    return torch.empty(rows, cols).normal_(0, 1)
def get_mat(wc_dim=200):
    """Build a weight-normalized square ``nn.Linear`` and return its weight.

    NOTE(review): the original docstring called this "what we send during the
    initialization" -- presumably the tensor that would later be orthogonally
    initialized; confirm against the original use case.

    Args:
        wc_dim: Number of input and output features of the layer
            (default 200, the value previously hard-coded).

    Returns:
        The ``(wc_dim, wc_dim)`` weight tensor produced by ``weight_norm``,
        taken via ``.data`` so it does not require grad.
    """
    layer = nn.utils.weight_norm(
        nn.Linear(wc_dim, wc_dim, bias=False), name="weight"
    )
    return layer.weight.data
def exp():
    """Run ``torch.qr`` on a random matrix shaped like the weight-norm weight.

    Per this script's intent, the call is expected to succeed in the parent
    process and to crash when run in a multiprocessing child process.

    Returns:
        Whatever ``torch.qr`` returns for the random matrix.
    """
    # get_flattened reads only the weight's shape; its values are discarded.
    return torch.qr(get_flattened(get_mat()))
if __name__ == "__main__":
    # Request "fork" explicitly: the crash depends on the child inheriting the
    # parent's already-initialized state, and "fork" is not the default start
    # method on macOS (spawn, Python 3.8+) or on Linux from Python 3.14
    # (forkserver). get_context raises ValueError where fork is unavailable.
    ctx = mp.get_context("fork")

    # First call, in the parent process: expected to succeed. A crash here
    # kills the whole script before the check below is reached.
    result = exp()
    # Explicit checks instead of `assert` so they still run under `python -O`.
    if result is None:
        raise SystemExit("Oops, it should have not failed here")

    # Now run the same call in a child process.
    p = ctx.Process(target=exp)
    p.start()
    p.join()
    if p.exitcode == 0:
        raise SystemExit("Child process call to qr did not fail")
    # multiprocessing reports death by signal N as exit code -N.
    if p.exitcode != -signal.SIGSEGV:
        raise SystemExit("The error is not a segmentation fault %s" % p.exitcode)
    print("Child process call to QR failed as expected")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.