@mrshenli
Last active January 17, 2021 02:13

import torch
import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe
from torch.distributed import rpc

N = 2000  # feature dimension of the Linear layers
B = 10    # batch size


def func(x):
    # simple elementwise workload used for the single-device baseline
    y = torch.zeros_like(x)
    for _ in range(10):
        y += x
    return y


# warmup: run the workload a few times on every visible GPU
for d in range(torch.cuda.device_count()):
    with torch.cuda.device(d):
        for _ in range(3):
            func(torch.ones([N, N], device="cuda"))

# measure the workload on a single device with CUDA events
e0 = torch.cuda.Event(enable_timing=True)
e1 = torch.cuda.Event(enable_timing=True)
with torch.cuda.device(0):
    e0.record()
    func(torch.ones([N, N], device="cuda"))
    e1.record()
    e1.synchronize()
    print(f"e1 - e0: {e0.elapsed_time(e1)}")

# measure Pipe: RPC must be initialized before constructing Pipe
rpc.init_rpc(
    "worker0",
    rank=0,
    world_size=1,
    rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
        init_method="tcp://localhost:23456"
    )
)

# two-stage model: one Linear layer per GPU
model = nn.Sequential(
    nn.Linear(N, N).to("cuda:0"),
    nn.Linear(N, N).to("cuda:1")
)
pipe = Pipe(model, chunks=2)  # split each batch into 2 micro-batches
inp = torch.zeros(B, N).to("cuda:0")

# warmup: Pipe.forward returns an RRef, so fetch the tensor with local_value()
for _ in range(3):
    pipe(inp).local_value().sum().backward()
print(f"peak mem - cuda:0 = {torch.cuda.memory_stats(0)['allocated_bytes.all.peak']}")
print(f"peak mem - cuda:1 = {torch.cuda.memory_stats(1)['allocated_bytes.all.peak']}")
# record events around the pipelined forward and backward passes
e_bfr_fp0 = torch.cuda.Event(enable_timing=True)  # before forward, cuda:0
e_aft_fp0 = torch.cuda.Event(enable_timing=True)  # after forward, cuda:0
e_aft_fp1 = torch.cuda.Event(enable_timing=True)  # after forward, cuda:1
e_aft_bp0 = torch.cuda.Event(enable_timing=True)  # after backward, cuda:0

with torch.cuda.device(0):
    e_bfr_fp0.record()

out = pipe(inp).local_value()

with torch.cuda.device(1):
    e_aft_fp1.record()

with torch.cuda.device(0):
    # make cuda:0's current stream wait for the forward work recorded on cuda:1
    # before stamping the end-of-forward event on cuda:0
    e_aft_fp1.wait(torch.cuda.current_stream())
    e_aft_fp0.record()

out.sum().backward()

with torch.cuda.device(0):
    e_aft_bp0.record()
    e_aft_bp0.synchronize()

print(f"fw: {e_bfr_fp0.elapsed_time(e_aft_fp0)}")
print(f"fw + bw: {e_bfr_fp0.elapsed_time(e_aft_bp0)}")
print(f"peak mem - cuda:0 = {torch.cuda.memory_stats(0)['allocated_bytes.all.peak']}")
print(f"peak mem - cuda:1 = {torch.cuda.memory_stats(1)['allocated_bytes.all.peak']}")

del pipe
torch.distributed.rpc.shutdown()
"""
outputs:
e1 - e0: 0.9849920272827148
peak mem - cuda:0 = 50517504
peak mem - cuda:1 = 50591744
fw: 2.613663911819458
fw + bw: 4.876895904541016
peak mem - cuda:0 = 50517504
peak mem - cuda:1 = 50672128
"""