-
-
Save ArturNiederfahrenhorst/8d041afc3b2de4f6da0f3e63c39d21c6 to your computer and use it in GitHub Desktop.
Script that demonstrates torch.compile speed (or slowness) on RLlib RL modules
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import gymnasium as gym
import numpy as np
import torch
import torch._dynamo as dynamo
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchCompileConfig
from ray.rllib.models.catalog import MODEL_DEFAULTS
from ray.rllib.policy.sample_batch import SampleBatch
torch.set_float32_matmul_precision('high')

# Fixed, hand-written CartPole episode of three time steps. It is used as the
# (non-random) input batch for every benchmark call in this script.
_OBS_ROWS = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]
_ACTION_SEQ = [0, 1, 1]
_REWARD_SEQ = [1.0, -1.0, 0.5]

# Each entry gets its own freshly built array, so no two keys share storage.
FAKE_BATCH = {
    SampleBatch.OBS: np.array(_OBS_ROWS, dtype=np.float32),
    SampleBatch.NEXT_OBS: np.array(_OBS_ROWS, dtype=np.float32),
    SampleBatch.ACTIONS: np.array(_ACTION_SEQ),
    SampleBatch.PREV_ACTIONS: np.array(_ACTION_SEQ),
    SampleBatch.REWARDS: np.array(_REWARD_SEQ, dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array(_REWARD_SEQ, dtype=np.float32),
    SampleBatch.TERMINATEDS: np.array([False, False, True]),
    SampleBatch.TRUNCATEDS: np.array([False] * 3),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32
    ),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    SampleBatch.EPS_ID: np.array([0] * 3),
    SampleBatch.AGENT_INDEX: np.array([0] * 3),
}
# Build two identically configured PPO modules on CUDA device 0. One stays
# eager; the other is compiled, so their forward passes can be compared.
model_cfg = MODEL_DEFAULTS.copy()
env = gym.make("CartPole-v1")
spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    catalog_class=PPOCatalog,
    model_config_dict=model_cfg,
)
# Only the env's spaces are needed, and the spec already holds them, so
# release the environment instead of leaking it.
env.close()
eager_module = spec.build().to(0)
compiled_module = spec.build().to(0)
compile_config = TorchCompileConfig(
    torch_dynamo_mode="max-autotune",
)
# Compiling is host-side work, so time it with a wall clock. CUDA events only
# timestamp the GPU stream, which may still hold work queued by `.to(0)`.
# NOTE(review): torch.compile compiles lazily, on the first forward call, so
# most of the real compile cost is expected in the "Compile warmup" timing
# below rather than here.
compile_start = time.perf_counter()
# Compile only one of the modules.
compiled_module = compiled_module.compile(compile_config)
print("Time to compile: ", time.perf_counter() - compile_start)
# Number of timed forward passes per module in the benchmark loops.
N_ITERS = 10


def timed(fn):
    """Call ``fn()`` once with autograd disabled and time it via CUDA events.

    Returns:
        A ``(result, seconds)`` tuple. ``result`` is whatever ``fn()``
        returned. ``seconds`` is the CUDA-event elapsed time around the call,
        converted from milliseconds to seconds.
    """
    begin_evt = torch.cuda.Event(enable_timing=True)
    finish_evt = torch.cuda.Event(enable_timing=True)
    begin_evt.record()
    with torch.no_grad():
        out = fn()
    finish_evt.record()
    # Wait for the GPU to reach finish_evt so elapsed_time() is valid.
    torch.cuda.synchronize()
    elapsed_ms = begin_evt.elapsed_time(finish_evt)
    return out, elapsed_ms / 1000
def evaluate(mod, inp):
    """Run a single exploration forward pass of ``mod`` on ``inp`` and return it."""
    explore = mod.forward_exploration
    return explore(inp)
def generate_data(device=0):
    """Return ``FAKE_BATCH`` converted to a dict of float32 tensors on ``device``.

    The data is not random: every call converts the same ``FAKE_BATCH`` arrays.
    Every column, including the integer and boolean ones, is cast to
    ``torch.float32``. Each call produces fresh copies, so mutating the
    returned tensors never touches ``FAKE_BATCH``.

    Args:
        device: Target device. Defaults to ``0``, i.e. the first CUDA device.
    """
    return {
        name: torch.tensor(arr, dtype=torch.float32, device=device)
        for name, arr in FAKE_BATCH.items()
    }


def generate_target(device=0):
    """Return a fixed (3, 2) float32 target tensor on ``device`` (default CUDA 0)."""
    return torch.tensor(
        [[-1.0, 2.5], [-1.0, -2.3], [-1.0, 2.5]],
        dtype=torch.float32,
        device=device,
    )
def _benchmark(label, module):
    """Time ``N_ITERS`` exploration passes of ``module`` and return the timings.

    A fresh batch is built with ``generate_data()`` before each call. Each
    per-call time in seconds, as returned by ``timed``, is printed as
    ``"<label> eval time <i>: <seconds>"``.
    """
    times = []
    for i in range(N_ITERS):
        inp = generate_data()
        # Bind `inp` as a default so the lambda does not rely on late binding.
        _, elapsed = timed(lambda inp=inp: evaluate(module, inp))
        times.append(elapsed)
        print(f"{label} eval time {i}: {elapsed}")
    return times


# Time the first call of each module separately. For the compiled module this
# is presumably where the lazy torch.compile work actually happens.
inp = generate_data()
print("Eager warmup:", timed(lambda: evaluate(eager_module, inp))[1])
print("Compile warmup:", timed(lambda: evaluate(compiled_module, inp))[1])
print("~" * 10)
eager_times = _benchmark("eager", eager_module)
print("~" * 10)
compile_times = _benchmark("compile", compiled_module)

eager_med = np.median(eager_times)
eager_mean = np.mean(eager_times)
compile_med = np.median(compile_times)
compile_mean = np.mean(compile_times)
speedup_median = eager_med / compile_med
speedup_mean = eager_mean / compile_mean
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup median: {speedup_median}x")
print(f"(eval) eager mean: {eager_mean}, compile mean: {compile_mean}, speedup mean: {speedup_mean}x")
print("~" * 10)
print("Torch dynamo explain output to check if graph has any breaks:")
dynamo_explanation = dynamo.explain(
    eager_module._forward_exploration, generate_data())
# torch 2.0 returns a tuple whose element 5 is the verbose explanation. Newer
# releases return an ExplainOutput object instead: indexing it raises
# TypeError, but printing it yields a summary that includes graph-break info.
if isinstance(dynamo_explanation, tuple):
    print(dynamo_explanation[5])
else:
    print(dynamo_explanation)
print("~" * 10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment