@matt-peters
Created June 15, 2020 18:02
PyTorch + TensorFlow running together in the same computational graph
"""
An example of running both pytorch and tensorflow in the same network,
while passing activations and gradients between the two.
In this example, we run a simple 2-layer feed-forward network,
with the first layer of size (5, 2) and the second of size (2, 3).
The code implements the forward/backward pass in three versions:
* tensorflow only
* pytorch only
* pytorch for the first layer, then tensorflow for the second, with
  activations/gradients passed between the layers.
"""
import numpy as np
# Define the input and weights here so all implementations are using
# the same underlying data and parameters.
batch_size = 16
embed_dim = 5
np.random.seed(5)
x_input = np.random.normal(size=(batch_size, embed_dim)).astype(np.float32)
np_W1 = np.random.normal(size=(embed_dim, 2)).astype(np.float32)
np_W2 = np.random.normal(size=(2, 3)).astype(np.float32)
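

# For reference (not part of the original gist), the same forward/backward
# pass can be written in closed form with numpy alone; every implementation
# below should reproduce this loss and these gradients.
def run_numpy():
    h = x_input @ np_W1          # first layer activations, (batch_size, 2)
    y = h @ np_W2                # network output, (batch_size, 3)
    loss = np.mean(y ** 2)
    # d(loss)/dy for a mean of squares is 2 * y / (number of elements)
    dy = 2.0 * y / y.size
    W2_grad = h.T @ dy           # gradient w.r.t. the second layer weights
    dh = dy @ np_W2.T            # gradient flowing back into the activations
    W1_grad = x_input.T @ dh     # gradient w.r.t. the first layer weights
    print(loss)
    print(W1_grad)
    print(W2_grad)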


def run_tf():
    import tensorflow as tf

    with tf.Session() as session:
        x = tf.placeholder(tf.float32, shape=(None, embed_dim), name='x')
        # MLP: (5, 2), (2, 3)
        W1 = tf.Variable(np_W1)
        W2 = tf.Variable(np_W2)
        session.run(tf.global_variables_initializer())
        # (batch_size, 3)
        y = tf.matmul(tf.matmul(x, W1), W2)
        loss = tf.reduce_mean(y ** 2)
        gradient_op = tf.gradients(loss, [W1, W2])
        # Run the loss together with the gradients so its value is printed,
        # rather than the symbolic tensor.
        loss_val, W1_grad, W2_grad = session.run(
            [loss] + gradient_op, feed_dict={x: x_input})
        print(loss_val)
        print(W1_grad)
        print(W2_grad)
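

# The tensorflow-only version above uses the 1.x Session API. A rough
# TF 2.x equivalent (a sketch, not part of the original gist; the name
# run_tf2 is hypothetical, and it requires tensorflow >= 2.0, so it cannot
# run alongside run_tf in the same environment) would use tf.GradientTape:
def run_tf2():
    import tensorflow as tf

    W1 = tf.Variable(np_W1)
    W2 = tf.Variable(np_W2)
    with tf.GradientTape() as tape:
        # Variables are watched by the tape automatically.
        y = tf.matmul(tf.matmul(x_input, W1), W2)
        loss = tf.reduce_mean(y ** 2)
    W1_grad, W2_grad = tape.gradient(loss, [W1, W2])
    print(loss.numpy())
    print(W1_grad.numpy())
    print(W2_grad.numpy())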


def run_torch():
    import torch

    class FF(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # nn.Linear stores its weight as (out_features, in_features),
            # so copy in the transpose of the (in, out) numpy arrays.
            self.W1 = torch.nn.Linear(5, 2, bias=False)
            self.W1.weight.data.copy_(torch.tensor(np_W1).t())
            self.W2 = torch.nn.Linear(2, 3, bias=False)
            self.W2.weight.data.copy_(torch.tensor(np_W2).t())

        def forward(self, x):
            return torch.mean(self.W2(self.W1(x)) ** 2)

    ff = FF()
    loss = ff(torch.tensor(x_input))
    loss.backward()
    print(loss)
    # Transpose back to (in, out) so the gradients line up with the
    # tensorflow version.
    print(ff.W1.weight.grad.t())
    print(ff.W2.weight.grad.t())


def run_both():
    import tensorflow as tf
    import torch

    session = tf.Session()

    # torch W1
    class FF(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.W1 = torch.nn.Linear(5, 2, bias=False)
            self.W1.weight.data.copy_(torch.tensor(np_W1).t())

        def forward(self, x):
            return self.W1(x)

    ff = FF()

    # tensorflow W2
    # output from W1/input to W2 is size (batch_size, 2)
    tf_w1_output = tf.placeholder(tf.float32, shape=(None, 2), name='x')
    W2 = tf.Variable(np_W2)
    session.run(tf.global_variables_initializer())
    y = tf.matmul(tf_w1_output, W2)
    loss = tf.reduce_mean(y ** 2)
    # Take gradients of both the W2 parameters and the input activations
    # to W2, so the activation gradient can be handed back to torch.
    gradient_op = tf.gradients(loss, [tf_w1_output, W2])

    # Finished setup. Now run the computation.
    # forward torch
    torch_w1_output = ff(torch.tensor(x_input))
    # forward tf, then backward tf
    w1_output_grad, W2_grad = session.run(
        gradient_op,
        feed_dict={tf_w1_output: torch_w1_output.detach().cpu().numpy()})
    # backward torch, seeded with the activation gradient from tensorflow
    torch_w1_output.backward(torch.tensor(w1_output_grad))
    print(ff.W1.weight.grad.t())
    print(W2_grad)
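

# An alternative structuring (an illustrative sketch, not from the original
# gist; the names TfSecondLayer and run_both_autograd are hypothetical):
# wrap the tensorflow half in a torch.autograd.Function so that a single
# loss.backward() on the torch side drives both frameworks.
def run_both_autograd():
    import tensorflow as tf
    import torch

    session = tf.Session()
    tf_h = tf.placeholder(tf.float32, shape=(None, 2))
    W2 = tf.Variable(np_W2)
    session.run(tf.global_variables_initializer())
    tf_loss = tf.reduce_mean(tf.matmul(tf_h, W2) ** 2)
    # Gradient of the tf loss w.r.t. its input activations.
    tf_h_grad = tf.gradients(tf_loss, [tf_h])[0]

    class TfSecondLayer(torch.autograd.Function):
        @staticmethod
        def forward(ctx, h):
            ctx.np_h = h.detach().cpu().numpy()
            return torch.tensor(session.run(tf_loss, feed_dict={tf_h: ctx.np_h}))

        @staticmethod
        def backward(ctx, grad_output):
            # The tf loss is the final scalar here, so scale the tf-computed
            # input gradient by the incoming grad_output (1.0 in this case).
            h_grad = session.run(tf_h_grad, feed_dict={tf_h: ctx.np_h})
            return grad_output * torch.tensor(h_grad)

    W1 = torch.nn.Linear(5, 2, bias=False)
    W1.weight.data.copy_(torch.tensor(np_W1).t())
    loss = TfSecondLayer.apply(W1(torch.tensor(x_input)))
    loss.backward()
    print(W1.weight.grad.t())


# Not in the original gist: run the versions side by side so the printed
# losses and gradients can be compared. (run_tf and run_both assume a
# tensorflow 1.x installation; run_tf2 would need a 2.x install instead.)
if __name__ == '__main__':
    run_numpy()
    run_tf()
    run_torch()
    run_both()
    run_both_autograd()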

@pangyuteng:

wow. this is pretty neat!