@robieta
Created August 22, 2020 00:09
"""
$ python fuzz_conv.py
NOTE: After killing with ctrl-c, it is generally necessary to kill stragglers:
$ pkill -f fuzz_conv
"""
import argparse
import datetime
import multiprocessing
import multiprocessing.dummy
import os
import pickle
import queue
import subprocess
import tempfile
import threading
import time

import numpy as np
import torch
from torch.utils._benchmark import Timer, Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias
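

# Each Timer measurement runs for at least _MIN_RUN_TIME seconds, and each worker
# subprocess draws _NUM_PER_SUBPROCESS cases from each fuzzer.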
_MIN_RUN_TIME = 1
_NUM_PER_SUBPROCESS = 5
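
# Pool of even-numbered CPU ids; each worker subprocess is pinned to one via
# taskset (even ids are used so that workers presumably land on distinct
# physical cores on a hyperthreaded machine).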
_POOL_SIZE = int(multiprocessing.cpu_count() / 2)
_CORE_POOL = queue.Queue()
for i in range(_POOL_SIZE):
    _CORE_POOL.put(2 * i)
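
# All subprocess results are merged into a single pickle file; the lock
# serializes writes from the driver's thread pool.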
_RESULT_FILE = "/tmp/conv_test.pkl"
_RESULT_FILE_LOCK = threading.Lock()

_MIN_POW2_SIZE = 2
_MAX_POW2_SIZE = 2048
_POW_TWO_SIZES = tuple(2 ** i for i in range(
    int(np.log2(_MIN_POW2_SIZE)),
    int(np.log2(_MAX_POW2_SIZE)) + 1,
))

# Wouldn't be a benchmark without ResNet!
_RESNET_SIZES = (224, 112, 56, 28, 14, 7)
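

# Convert a sequence of values into a uniform {value: probability} distribution
# for FuzzedParameter, optionally discarding values below `minval`.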
def as_normalized_dict(values, minval=None):
    if minval is not None:
        values = tuple(i for i in values if i >= minval)
    return {v: 1 / len(values) for v in values}
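

# Fuzzer producing Conv2d workloads: input shape (N, C, H, W), output channels,
# kernel size, and stride, with extra probability mass on power-of-two sizes,
# RGB (3) input channels, and ResNet-style spatial sizes.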
class ConvFuzzer(Fuzzer):
    def __init__(self, seed, dtype=torch.float32):
        super().__init__(
            parameters=[
                # Batch size
                FuzzedParameter(name="N_any", minval=1, maxval=1024, distribution="loguniform"),
                FuzzedParameter(name="N_pow_2", distribution=as_normalized_dict(_POW_TWO_SIZES)),
                FuzzedParameter(name="N", distribution={
                    ParameterAlias("N_any"): 0.5,
                    ParameterAlias("N_pow_2"): 0.5,
                }),

                # Channels
                # Channels are generally either:
                #   A) A power of two, because the input is the output of a prior
                #      convolution with a power-of-two number of channels.
                #   B) Three, due to RGB.
                # As a result, these two cases are given extra probability.
                FuzzedParameter(name="C_any", minval=1, maxval=1024, distribution="loguniform"),
                FuzzedParameter(name="C_pow_2", distribution=as_normalized_dict(_POW_TWO_SIZES)),
                FuzzedParameter(name="C", distribution={
                    ParameterAlias("C_any"): 0.25,
                    ParameterAlias("C_pow_2"): 0.65,
                    3: 0.1,
                }),

                # H and W
                FuzzedParameter("H_any", minval=7, maxval=500, distribution="loguniform"),
                FuzzedParameter("H_pow2", distribution=as_normalized_dict(_POW_TWO_SIZES, minval=8)),
                FuzzedParameter("H_resnet", distribution=as_normalized_dict(_RESNET_SIZES)),
                FuzzedParameter("W_any", minval=7, maxval=500, distribution="loguniform"),
                FuzzedParameter("W_pow2", distribution=as_normalized_dict(_POW_TWO_SIZES, minval=8)),
                FuzzedParameter("H", distribution={
                    ParameterAlias("H_any"): 0.4,
                    ParameterAlias("H_pow2"): 0.4,
                    ParameterAlias("H_resnet"): 0.2,
                }),

                # Square images are unusually common, so half the time width simply
                # mirrors height.
                FuzzedParameter("W", distribution={
                    ParameterAlias("H"): 0.5,
                    ParameterAlias("W_any"): 0.25,
                    ParameterAlias("W_pow2"): 0.25,
                }),

                # Output channels
                FuzzedParameter("out_channels_any", minval=4, maxval=1024, distribution="loguniform"),
                FuzzedParameter("out_channels_pow2", distribution=as_normalized_dict(_POW_TWO_SIZES)),
                FuzzedParameter("out_channels", distribution={
                    ParameterAlias("out_channels_any"): 0.5,
                    ParameterAlias("out_channels_pow2"): 0.5,
                }),

                # Kernel sizes and strides
                FuzzedParameter("kernel_H", minval=1, maxval=7, distribution="uniform"),
                FuzzedParameter("kernel_W_candidate", minval=1, maxval=7, distribution="uniform"),
                FuzzedParameter("kernel_W", distribution={
                    ParameterAlias("kernel_H"): 0.5,
                    ParameterAlias("kernel_W_candidate"): 0.5,
                }),
                FuzzedParameter("stride_H", minval=1, maxval=3, distribution="uniform"),
                FuzzedParameter("stride_W_candidate", minval=1, maxval=3, distribution="uniform"),
                FuzzedParameter("stride_W", distribution={
                    ParameterAlias("stride_H"): 0.5,
                    ParameterAlias("stride_W_candidate"): 0.5,
                }),
            ],
            tensors=[
                FuzzedTensor(
                    name="x",
                    size=("N", "C", "H", "W"),
                    # TODO(robieta): steps
                    probability_contiguous=0.8,
                    max_elements=1024 ** 2,
                    dtype=dtype,
                    cuda=False,
                ),
            ],
            seed=seed,
        )
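

# Fuzzer producing matmul workloads: (K0 x K1) @ (K1 x K2), with extra probability
# mass on power-of-two dims and occasional aliasing so square matrices show up.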
class MatMulFuzzer(Fuzzer):
    def __init__(self, seed, dtype=torch.float32):
        super().__init__(
            parameters=[
                [
                    FuzzedParameter(name=f"K{i}_any", minval=1, maxval=2048, distribution="loguniform")
                    for i in range(3)
                ],
                [
                    FuzzedParameter(name=f"K{i}_pow_2", distribution=as_normalized_dict(_POW_TWO_SIZES))
                    for i in range(3)
                ],
                FuzzedParameter(name="K0", distribution={
                    ParameterAlias("K0_any"): 0.5,
                    ParameterAlias("K0_pow_2"): 0.5,
                }),
                # Square matrices are somewhat common, so we sometimes
                # alias K1 and K2 to other dims.
                FuzzedParameter(name="K1", distribution={
                    ParameterAlias("K1_any"): 0.3,
                    ParameterAlias("K1_pow_2"): 0.4,
                    ParameterAlias("K0"): 0.3,
                }),
FuzzedParameter(name="K2", distribution={
ParameterAlias("K2_any"): 0.3,
ParameterAlias("K2_pow_2"): 0.4,
ParameterAlias("K0"): 0.15,
ParameterAlias("K1"): 0.15,
}),
],
tensors=[
FuzzedTensor(
name="x",
size=("K0", "K1"),
probability_contiguous=1,
max_elements=1024 ** 2,
dtype=dtype,
cuda=False,
),
FuzzedTensor(
name="y",
size=("K1", "K2"),
probability_contiguous=1,
max_elements=1024 ** 2,
dtype=dtype,
cuda=False,
),
],
constraints=[
lambda params: params["K0"] * params["K1"] * params["K2"] < 1024**3
],
seed=seed,
)
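

# Worker entry point: sample _NUM_PER_SUBPROCESS cases from each fuzzer, time them,
# and append pickled measurements to `result_file`. Conv2d is timed both with
# MKL-DNN enabled and disabled so the two CPU backends can be compared on identical
# inputs; matmul is timed once.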
def _subprocess_main(seed, result_file=None):
    for tensors, tensor_params, params in ConvFuzzer(seed).take(_NUM_PER_SUBPROCESS):
        model = torch.nn.Conv2d(
            params["C"], params["out_channels"],
            (params["kernel_H"], params["kernel_W"]),
            stride=(params["stride_H"], params["stride_W"]),
            # TODO(robieta): padding and groups,
            padding=(0, 0), groups=1, bias=True
        )
        model.eval()

        timer = Timer(
            stmt="model(x)",
            globals={
                "model": model,
                "x": tensors["x"],
            },
        )
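        # Time the same Conv2d module with MKL-DNN enabled and then disabled so both
        # backends see the identical input; the default (enabled) state is restored
        # afterwards.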
        torch._C._set_mkldnn_enabled(True)
        mkl_time = timer.blocked_autorange(min_run_time=_MIN_RUN_TIME)

        torch._C._set_mkldnn_enabled(False)
        native_time = timer.blocked_autorange(min_run_time=_MIN_RUN_TIME)
        torch._C._set_mkldnn_enabled(True)

        if result_file:
            with open(result_file, "ab") as f:
                pickle.dump({
                    "op": "Conv",
                    "params": params,
                    "tensor_params": tensor_params,
                    "mkl_time": mkl_time,
                    "native_time": native_time,
                }, f)

    for tensors, tensor_params, params in MatMulFuzzer(seed).take(_NUM_PER_SUBPROCESS):
        timer = Timer(
            stmt="torch.matmul(x, y)",
            globals=tensors,
        )
        t = timer.blocked_autorange(min_run_time=_MIN_RUN_TIME)

        if result_file:
            with open(result_file, "ab") as f:
                pickle.dump({
                    "op": "MatMul",
                    "params": params,
                    "tensor_params": tensor_params,
                    "time": t,
                }, f)
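

# Launch one benchmark subprocess pinned (via taskset) to a core drawn from
# _CORE_POOL and forced to run single threaded through environment variables,
# then merge its temporary result file into the shared _RESULT_FILE.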
def run_subprocess(seed):
    try:
        core = _CORE_POOL.get()
        _, result_file = tempfile.mkstemp(suffix=".pkl")
        subprocess.run(
            f"taskset --cpu-list {core} "
            "python fuzz_conv.py "
            "--DETAIL_in_subprocess "
            f"--DETAIL_seed {seed} "
            f"--DETAIL_result_file {result_file}",
            env={
                "PATH": os.getenv("PATH"),
                "PYTHONPATH": os.getenv("PYTHONPATH") or "",
                "OMP_NUM_THREADS": "1",
                "MKL_NUM_THREADS": "1",
                "NUMEXPR_NUM_THREADS": "1",
            },
            stdout=subprocess.PIPE,
            shell=True
        )
        with _RESULT_FILE_LOCK, \
                open(result_file, "rb") as f_in, \
                open(_RESULT_FILE, "ab") as f_out:
            f_out.write(f_in.read())
    except KeyboardInterrupt:
        pass  # Handle ctrl-c gracefully.
    finally:
        _CORE_POOL.put(core)
        if os.path.exists(result_file):
            os.remove(result_file)
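

# Driver: a thread pool (multiprocessing.dummy) fans subprocess runs out across
# the core pool and prints a running ETA as results come back.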
def main():
    n = 10000
    workers = int(_POOL_SIZE)
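
    # Truncate the shared result file so this run starts from a clean slate.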
    with open(_RESULT_FILE, "wb"):
        pass

    with multiprocessing.dummy.Pool(workers) as pool:
        start_time = time.time()
        for i, r in enumerate(pool.imap(run_subprocess, range(n))):
            n_trials_done = (i + 1) * _NUM_PER_SUBPROCESS
            time_per_result = (time.time() - start_time) / n_trials_done
            eta = int((n * _NUM_PER_SUBPROCESS - n_trials_done) * time_per_result)
            print(f"\r{i + 1} / {n} ETA:{datetime.timedelta(seconds=eta)}".ljust(80), end="")
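

# The --DETAIL_* flags are internal: run_subprocess passes them when this script
# re-invokes itself as a pinned, single-threaded worker.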
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--DETAIL_in_subprocess", action="store_true")
    parser.add_argument("--DETAIL_result_file", type=str, default=None)
    parser.add_argument("--DETAIL_seed", type=int, default=None)
    args = parser.parse_args()

    if args.DETAIL_in_subprocess:
        try:
            _subprocess_main(args.DETAIL_seed, args.DETAIL_result_file)
        except KeyboardInterrupt:
            pass  # Handle ctrl-c gracefully.
    else:
        main()