Created
March 10, 2021 20:11
-
-
Save robieta/16b91831721cb8be121f04ed9c917375 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import multiprocessing.dummy | |
import os | |
import pickle | |
import shutil | |
import subprocess | |
import tempfile | |
import uuid | |
import torch | |
from torch.utils.benchmark import Compare, Timer | |
# Conda environment names for the two PyTorch builds under comparison:
# a reference (baseline) build and a build with PR #52875 applied.
# worker_main asserts it is running inside one of these two envs.
REF_ENV = "52875_ref"
PR_ENV = "52875_pr"
# Default device for input generation; this benchmark runs on CPU.
cpu = torch.device('cpu')
def generate_input(shape, dtype=torch.double, device=torch.device("cpu")):
    """Build a random diagonalizable matrix with gradient tracking enabled.

    The matrix is constructed as ``V @ diag(w) @ V^-1`` from random
    eigenvalues ``w`` and eigenvectors ``V``, so its eigendecomposition
    exists by construction — a well-behaved input for ``torch.eig``.

    Args:
        shape: Size of the returned tensor; the last two dims form the
            (square) matrix, leading dims are batch dims.
        dtype: Element type (default double, for a numerically stable
            ``inverse``).
        device: Device to allocate on (default CPU; equal to the original
            module-level ``cpu`` default, now inlined so the function is
            self-contained).

    Returns:
        A ``shape``-sized tensor with ``requires_grad=True``.
    """
    eigvals = torch.rand(*shape[:-1], dtype=dtype, device=device)
    eigvecs = torch.rand(*shape, dtype=dtype, device=device)
    # Renamed from `input`, which shadowed the builtin of the same name.
    matrix = (eigvecs * eigvals.unsqueeze(-2)) @ eigvecs.inverse()
    matrix.requires_grad_(True)
    return matrix
def worker_main(result_dir, size=10, num_threads=1):
    """Benchmark the backward pass of ``torch.eig`` and pickle the timing.

    Intended to run inside one of the two comparison conda environments
    (REF_ENV / PR_ENV); the active env name is attached to the measurement
    so ``Compare`` can group results by build. The resulting Measurement is
    written to a uniquely-named pickle in ``result_dir`` for the parent
    process to collect.
    """
    # The conda env name doubles as the A/B label for the comparison table.
    conda_env = os.path.split(os.getenv("CONDA_PREFIX"))[1]
    assert conda_env in (REF_ENV, PR_ENV)

    x = generate_input((size, size))
    # NOTE(review): torch.eig is deprecated (removed in recent PyTorch);
    # kept as-is since this script benchmarks precisely this op's backward.
    eigvals, eigvecs = torch.eig(x, eigenvectors=True)

    timer = Timer(
        """
        torch.autograd.backward(
            (eigvals, eigvecs),
            (onesvals, onesvecs),
            retain_graph=True
        )
        """,
        globals={
            "x": x,
            "eigvals": eigvals,
            "eigvecs": eigvecs,
            "onesvals": torch.ones_like(eigvals),
            "onesvecs": torch.ones_like(eigvecs)
        },
        label="Eig backward",
        sub_label=".",
        description=f"size: {size}",
        env=conda_env,
        num_threads=num_threads,
    )
    measurement = timer.blocked_autorange(min_run_time=5)

    # Unique filename so concurrent workers never collide in result_dir.
    out_path = os.path.join(result_dir, f"{uuid.uuid4()}.pkl")
    with open(out_path, "wb") as fh:
        pickle.dump(measurement, fh)
def map_fn(cmd):
    """Execute one worker command line, raising on a non-zero exit status.

    ``shell=True`` is required here: each command begins with
    ``source activate <env>``, which only works through a shell.
    """
    subprocess.run(cmd, check=True, shell=True)
def main():
    """Fan out worker benchmarks across both envs and print a comparison.

    Spawns one subprocess per (env, size, num_threads) combination, each of
    which re-invokes this script in worker mode inside the target conda
    environment and drops a pickled Measurement into a shared temp dir.
    All measurements are then loaded and rendered as one Compare table.
    """
    result_dir = tempfile.mkdtemp()
    try:
        # Full cartesian product: 2 envs x 3 sizes x 3 thread counts.
        configs = [
            (env, size, num_threads)
            for env in (REF_ENV, PR_ENV)
            for size in (10, 100, 1000)
            for num_threads in (1, 2, 4)
        ]
        cmds = [
            f"source activate {env} && python {__file__} --mode worker "
            f"--result_dir {result_dir} --size {size} --num_threads {num_threads}"
            for env, size, num_threads in configs
        ]

        # dummy.Pool is thread-based; each thread just blocks on a subprocess,
        # so 4 workers run concurrently without extra processes here.
        with multiprocessing.dummy.Pool(4) as pool:
            pool.map(map_fn, cmds)

        results = []
        for fname in os.listdir(result_dir):
            with open(os.path.join(result_dir, fname), "rb") as fh:
                results.append(pickle.load(fh))

        comparison = Compare(results)
        comparison.trim_significant_figures()
        comparison.colorize()
        comparison.print()
    finally:
        # Always remove the scratch dir, even if a worker failed.
        shutil.rmtree(result_dir)
if __name__ == "__main__":
    # CLI: this script is both the orchestrator (--mode main, the default)
    # and the per-env worker it re-invokes (--mode worker), so one file can
    # be launched inside each conda environment.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="main")
    parser.add_argument("--size", default=10, type=int)
    parser.add_argument("--num_threads", default=1, type=int)
    # Only used in worker mode; set by main() to its shared temp dir.
    parser.add_argument("--result_dir", type=str)
    args = parser.parse_args()
    if args.mode == "main":
        main()
    else:
        # Any mode other than "main"/"worker" is a usage error.
        assert args.mode == "worker"
        worker_main(args.result_dir, size=args.size, num_threads=args.num_threads)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment