Script that demonstrates torch.compile speed/slowness on our RLlib modules.
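Assumes a CUDA-capable GPU at device 0, PyTorch 2.x (for torch.compile / torch._dynamo), and a Ray build that includes the experimental RLModule API.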
import numpy as np
import gymnasium as gym
import torch
import torch._dynamo as dynamo

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchCompileConfig
from ray.rllib.models.catalog import MODEL_DEFAULTS

# Allow TF32 matmuls ("high" precision) for faster float32 ops on Ampere+ GPUs.
torch.set_float32_matmul_precision("high")
# Fake CartPole episode of three time steps.
FAKE_BATCH = {
    SampleBatch.OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.NEXT_OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32,
    ),
    SampleBatch.ACTIONS: np.array([0, 1, 1]),
    SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.TERMINATEDS: np.array([False, False, True]),
    SampleBatch.TRUNCATEDS: np.array([False, False, False]),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32
    ),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    SampleBatch.EPS_ID: np.array([0, 0, 0]),
    SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
}
model_cfg = MODEL_DEFAULTS.copy()
env = gym.make("CartPole-v1")

spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    catalog_class=PPOCatalog,
    model_config_dict=model_cfg,
)

# Build two identical modules on GPU 0; only one of them gets compiled below.
eager_module = spec.build().to(0)
compiled_module = spec.build().to(0)

# "max-autotune" trades longer compile time for potentially faster kernels.
compile_config = TorchCompileConfig(
    torch_dynamo_mode="max-autotune",
)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
# Compile only one of the modules.
compiled_module = compiled_module.compile(compile_config)
end.record()
torch.cuda.synchronize()
print("Time to compile: ", start.elapsed_time(end) / 1000)
N_ITERS = 10

def timed(fn):
    # Time a single call to fn with CUDA events; returns (result, seconds).
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    with torch.no_grad():
        result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
# Evaluation
def evaluate(mod, inp):
    return mod.forward_exploration(inp)

# Converts the fixed fake batch into float32 tensors on GPU 0.
def generate_data():
    return {n: torch.Tensor(t).to(torch.float32).to(0) for n, t in FAKE_BATCH.items()}

# Fake action-distribution targets (unused in this benchmark).
def generate_target():
    return torch.Tensor(np.array([[-1.0, 2.5], [-1.0, -2.3], [-1, 2.5]])).to(torch.float32).to(0)
inp = generate_data()
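# First calls serve as warmup; for the compiled module this call triggers the
# actual dynamo graph capture and (with max-autotune) kernel tuning.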
print("Eager warmup:", timed(lambda: evaluate(eager_module, inp))[1])
print("Compile warmup:", timed(lambda: evaluate(compiled_module, inp))[1])
eager_times = []
compile_times = []
print("~" * 10)
for i in range(N_ITERS):
    inp = generate_data()
    _, eager_time = timed(lambda: evaluate(eager_module, inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")
print("~" * 10)
for i in range(N_ITERS):
    inp = generate_data()
    _, compile_time = timed(lambda: evaluate(compiled_module, inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
eager_med = np.median(eager_times)
eager_mean = np.mean(eager_times)
compile_med = np.median(compile_times)
compile_mean = np.mean(compile_times)
speedup_median = eager_med / compile_med
speedup_mean = eager_mean / compile_mean
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup median: {speedup_median}x")
print(f"(eval) eager mean: {eager_mean}, compile mean: {compile_mean}, speedup mean: {speedup_mean}x")
print("~" * 10)
print(f"Torch dynamo explain output to check if graph has any breaks:")
import torch._dynamo as dynamo
dynamo_explanation = dynamo.explain(
eager_module._forward_exploration, generate_data())
print(dynamo_explanation[5])
print("~" * 10)