Example of Python multi-threading giving a mix of .grad from different backward calls
import time
import threading

import torch
import torch.nn as nn


def simple_model(d, n):
    """Creates a linear neural network of n layers, each initialized to identity."""
    layers = []
    for i in range(n):
        layer = nn.Linear(d, d, bias=False)
        layer.weight.data.copy_(torch.eye(d))
        layers.append(layer)
    return torch.nn.Sequential(*layers)


def propagate(output, gradient, sleep_before1, sleep_before2, label):
    def f():
        time.sleep(sleep_before1)
        output.backward(gradient, retain_graph=True)
        # Snapshot the first layer's grad right after this thread's backward call
        grad1 = model[0].weight.grad.detach().clone()
        time.sleep(sleep_before2)
        # Read the second layer's grad after sleeping; by now another thread's
        # backward call may have accumulated into the same shared buffer
        grad2 = model[1].weight.grad
        print(f"{label} observed gradients ", grad1[0, 0].item(), grad2[0, 0].item())
    return threading.Thread(target=f, args=())


# Create a simple model and run two backward passes: one producing all gradients = 1,
# the other producing all gradients = 0
model = simple_model(2, 2)
x = torch.ones(1, 2)
y = model(x)

propagate2 = propagate(y, x, sleep_before1=0.5, sleep_before2=0, label="thread2")  # observes gradients 1, 1
propagate1 = propagate(y, x - x, sleep_before1=0, sleep_before2=1, label="thread1")  # should observe gradients 0, 0, but instead sees 0, 1 because thread2's backward accumulates into .grad between the two reads

propagate1.start()
propagate2.start()
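
The mix-up above arises because both threads accumulate into, and read from, the same shared .grad buffers on the model parameters. Below is a minimal sketch of one way to keep the two results separate, assuming a PyTorch version where concurrent backward calls are supported: torch.autograd.grad returns the gradients of that call directly instead of accumulating them into .grad. The helper name propagate_functional is hypothetical and not part of the original gist.

# Hypothetical variant (not in the original gist): torch.autograd.grad returns this
# call's gradients as fresh tensors instead of accumulating into the shared .grad
# buffers, so each thread only sees the result of its own backward pass.
def propagate_functional(output, gradient, label):
    def f():
        grads = torch.autograd.grad(output, [model[0].weight, model[1].weight],
                                    grad_outputs=gradient, retain_graph=True)
        print(f"{label} observed gradients ", grads[0][0, 0].item(), grads[1][0, 0].item())
    return threading.Thread(target=f, args=())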