@shunting314
Last active June 25, 2024 06:04
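import torch
import torch._inductor.config as inductor_config

# Disable Inductor's optimization for scattering into a constant tensor
# (optimize_scatter_upon_const_tensor), so the plain full + scatter_ path is compiled.
inductor_config.optimize_scatter_upon_const_tensor = False
torch.set_default_device("cuda")  # tensors created below default to the GPU

M, N = 1024, 2048
# One random column index in [0, N) for each of the M rows.
x = torch.randint(0, N, (M,), dtype=torch.int64)

@torch.compile
def f(x):
    # Constant [M, N] tensor filled with 3.14 ...
    y = torch.full([M, N], 3.14, dtype=torch.float)
    # ... then, in row i, overwrite column x[i] with 2.718.
    y.scatter_(1, x.unsqueeze(1), 2.718)
    return y

f(x)
print("bye")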
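For reference, here is a small CPU-only sketch (eager mode, no `torch.compile`) of what `f` computes. Scattering one scalar per row into a constant tensor gives the same result as a pointwise `torch.where` over a one-hot mask. The small sizes and the helper names `scatter_version` / `where_version` are illustrative assumptions. The `where` form is only a semantic equivalent, not necessarily the kernel Inductor emits when `optimize_scatter_upon_const_tensor` is left enabled.

```python
import torch

# Small stand-ins for the gist's M, N (assumption: any sizes show the same thing).
M, N = 4, 6
FILL, SRC = 3.14, 2.718

torch.manual_seed(0)
x = torch.randint(0, N, (M,), dtype=torch.int64)

def scatter_version(x):
    # Same computation as f in the gist, run eagerly on the CPU.
    y = torch.full([M, N], FILL, dtype=torch.float)
    y.scatter_(1, x.unsqueeze(1), SRC)
    return y

def where_version(x):
    # Pointwise equivalent: element (i, j) is SRC iff x[i] == j, else FILL.
    cols = torch.arange(N).unsqueeze(0)   # shape [1, N]
    mask = cols == x.unsqueeze(1)         # broadcasts to shape [M, N]
    return torch.where(mask, torch.tensor(SRC), torch.tensor(FILL))

a = scatter_version(x)
b = where_version(x)
print(torch.equal(a, b))  # prints True
```

Because every row receives exactly one index, no output position is written twice. That is why the scatter can be expressed as an independent per-element selection.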