# Benchmark script from GitHub gist robieta/16b91831721cb8be121f04ed9c917375
# (author: @robieta, created March 10, 2021 20:11).
import argparse
import multiprocessing.dummy
import os
import pickle
import shlex
import shutil
import subprocess
import tempfile
import uuid

import torch
from torch.utils.benchmark import Compare, Timer
# Names of the two conda environments whose results are compared: a reference
# build and a build with the change under test. The "52875" tag is presumably
# the number of the change being benchmarked — confirm.
REF_ENV, PR_ENV = "52875_ref", "52875_pr"

# Device on which benchmark inputs are created by default.
cpu = torch.device("cpu")
def generate_input(shape, dtype=torch.double, device=cpu):
    """Return a random matrix ``V @ diag(lam) @ inv(V)`` that requires grad.

    ``lam`` is drawn with shape ``shape[:-1]`` and ``V`` with shape ``shape``,
    both via ``torch.rand``. The result therefore has known eigen-structure.
    It assumes the last two dims of ``shape`` are equal (square matrices) —
    confirm for non-(n, n) callers.

    Args:
        shape: full shape of the returned tensor (leading dims are batch dims).
        dtype: element dtype of the generated tensors.
        device: device on which the tensors are created.

    Returns:
        A leaf tensor with ``requires_grad=True``.
    """
    factory_kwargs = dict(dtype=dtype, device=device)
    # Eigenvalues are sampled before eigenvectors to keep the RNG stream order.
    lam = torch.rand(*shape[:-1], **factory_kwargs)
    vecs = torch.rand(*shape, **factory_kwargs)
    # Broadcasting lam over rows scales column j of vecs by lam[..., j], which
    # is vecs @ diag(lam) without materializing the diagonal matrix.
    scaled = vecs * lam.unsqueeze(-2)
    matrix = scaled @ vecs.inverse()
    # requires_grad_ works in place and returns the same tensor.
    return matrix.requires_grad_(True)
def _current_conda_env():
    """Return the active conda environment name (basename of ``$CONDA_PREFIX``).

    Raises:
        RuntimeError: if ``CONDA_PREFIX`` is unset or empty.
    """
    prefix = os.getenv("CONDA_PREFIX")
    if not prefix:
        raise RuntimeError(
            "CONDA_PREFIX is not set; run the worker inside an activated conda env"
        )
    # normpath drops a trailing separator, which would otherwise make the
    # basename (and thus the env name) an empty string.
    return os.path.basename(os.path.normpath(prefix))


def worker_main(result_dir, size=10, num_threads=1):
    """Time eig backward for one configuration and pickle the measurement.

    Must run inside the ``REF_ENV`` or ``PR_ENV`` conda environment. The env
    name is recorded as the measurement's ``env`` so ``Compare`` can put the
    two builds side by side.

    Args:
        result_dir: directory receiving the pickled measurement. The file name
            is a fresh uuid4, so concurrent workers never collide.
        size: the benchmark input is a ``(size, size)`` matrix.
        num_threads: thread count passed to the ``Timer``.

    Raises:
        RuntimeError: if no conda env is active, or the active env is neither
            ``REF_ENV`` nor ``PR_ENV``.
    """
    conda_env = _current_conda_env()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if conda_env not in (REF_ENV, PR_ENV):
        raise RuntimeError(
            f"worker must run in {REF_ENV!r} or {PR_ENV!r}, got {conda_env!r}"
        )

    x = generate_input((size, size))
    # The forward pass runs once, outside the timed statement. retain_graph=True
    # in the statement lets backward be replayed on the same graph each run.
    eigvals, eigvecs = torch.eig(x, eigenvectors=True)
    timer = Timer(
        """
torch.autograd.backward(
(eigvals, eigvecs),
(onesvals, onesvecs),
retain_graph=True
)
""",
        globals={
            "x": x,
            "eigvals": eigvals,
            "eigvecs": eigvecs,
            # All-ones upstream gradients for both outputs.
            "onesvals": torch.ones_like(eigvals),
            "onesvecs": torch.ones_like(eigvecs),
        },
        label="Eig backward",
        sub_label=".",
        description=f"size: {size}",
        env=conda_env,
        num_threads=num_threads,
    )
    measurement = timer.blocked_autorange(min_run_time=5)
    with open(os.path.join(result_dir, f"{uuid.uuid4()}.pkl"), "wb") as f:
        pickle.dump(measurement, f)
def map_fn(cmd):
    """Run the shell command ``cmd``; raise ``CalledProcessError`` on non-zero exit.

    The commands ``main`` passes in start with ``source activate …``. ``source``
    is a bash builtin, not POSIX sh, so the command is run through bash when
    one is on PATH.
    """
    # shell=True defaults to /bin/sh, which is dash on Debian/Ubuntu and has no
    # `source` builtin. shutil.which returns None when bash is missing, which
    # leaves subprocess on its default shell.
    subprocess.run(cmd, shell=True, check=True, executable=shutil.which("bash"))
def main(*, pool_size=4):
    """Benchmark eig backward in both conda envs and print a comparison table.

    One worker subprocess (this script in ``--mode worker``) is launched for
    every (env, size, num_threads) combination. Each worker pickles its
    measurement into a shared temporary directory. The collected results are
    rendered with ``Compare``.

    Args:
        pool_size: number of worker subprocesses run concurrently.
            NOTE(review): concurrent workers compete for CPU cores, which can
            skew timings, especially for multi-threaded runs. ``pool_size=1``
            gives cleaner numbers at the cost of a longer run.
    """
    # TemporaryDirectory removes the directory even if a worker fails.
    with tempfile.TemporaryDirectory() as result_dir:
        # Both paths are interpolated into a shell command line, so quote them.
        script = shlex.quote(__file__)
        quoted_dir = shlex.quote(result_dir)
        cmds = [
            f"source activate {env} && python {script} --mode worker "
            f"--result_dir {quoted_dir} --size {size} --num_threads {num_threads}"
            for env in (REF_ENV, PR_ENV)
            for size in (10, 100, 1000)
            for num_threads in (1, 2, 4)
        ]
        # Threads suffice here: each task only waits on a subprocess.
        with multiprocessing.dummy.Pool(pool_size) as pool:
            pool.map(map_fn, cmds)

        results = []
        for name in sorted(os.listdir(result_dir)):
            # These pickles were written by this script's own workers.
            with open(os.path.join(result_dir, name), "rb") as f:
                results.append(pickle.load(f))

    compare = Compare(results)
    compare.trim_significant_figures()
    compare.colorize()
    compare.print()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark eig backward across the reference and PR conda envs."
    )
    # `choices` rejects typos up front. The previous `assert` was stripped
    # under `python -O`, which let any unknown mode fall through to worker mode.
    parser.add_argument(
        "--mode",
        default="main",
        choices=("main", "worker"),
        help="'main' orchestrates all runs; 'worker' times a single configuration.",
    )
    parser.add_argument("--size", default=10, type=int)
    parser.add_argument("--num_threads", default=1, type=int)
    parser.add_argument("--result_dir", type=str)
    args = parser.parse_args()
    if args.mode == "main":
        main()
    else:
        if args.result_dir is None:
            # Fail with a usage message instead of a TypeError inside the worker.
            parser.error("--result_dir is required with --mode worker")
        worker_main(args.result_dir, size=args.size, num_threads=args.num_threads)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment