
@mlazos
Created May 14, 2024 17:45
repro: CUDA memory-leak check when constructing a ChainedScheduler / SequentialLR on top of a capturable ASGD optimizer.
import torch
from torch.optim.lr_scheduler import ChainedScheduler, ConstantLR, SequentialLR
from torch.testing._internal.common_utils import CudaMemoryLeakCheck


def chained_fn():
    # Construct a capturable ASGD optimizer, then wrap two ConstantLR
    # schedulers in a ChainedScheduler, all inside a CUDA memory-leak check.
    with CudaMemoryLeakCheck(None, name="hi"):
        device = "cuda:0"
        dtype = torch.float32
        optim_cls = torch.optim.ASGD
        kwargs = {'lr': 0.001, 'weight_decay': 0.1, 'maximize': True, 'capturable': True, 'foreach': False}
        scheduler_cls = ChainedScheduler
        print(scheduler_cls)
        print(kwargs)
        input = torch.ones([10, 10], device=device)
        model_eager = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, device=device) for _ in range(1)]
        )
        model_eager(input).sum().backward()
        opt_eager = optim_cls(model_eager.parameters(), **kwargs)
        scheduler_eager = scheduler_cls(
            schedulers=[ConstantLR(opt_eager), ConstantLR(opt_eager)],
            optimizer=opt_eager,
        )


def sequential_fn():
    # Same setup, but with SequentialLR, which additionally takes milestones.
    with CudaMemoryLeakCheck(None, name="hi"):
        device = "cuda:0"
        dtype = torch.float32
        optim_cls = torch.optim.ASGD
        kwargs = {'lr': 0.001, 'weight_decay': 0.1, 'maximize': True, 'capturable': True, 'foreach': False}
        scheduler_cls = SequentialLR
        print(scheduler_cls)
        print(kwargs)
        input = torch.ones([10, 10], device=device)
        model_eager = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, device=device) for _ in range(1)]
        )
        model_eager(input).sum().backward()
        opt_eager = optim_cls(model_eager.parameters(), **kwargs)
        scheduler_eager = scheduler_cls(
            schedulers=[ConstantLR(opt_eager), ConstantLR(opt_eager)],
            optimizer=opt_eager,
            milestones=[0],
        )


chained_fn()
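
Note that sequential_fn is defined but never invoked; as written, only the ChainedScheduler case runs. To exercise the SequentialLR path as well, call it the same way (a one-line addition, not part of the original gist):

sequential_fn()

Both functions rely on CudaMemoryLeakCheck, an internal torch.testing utility that, roughly, snapshots CUDA caching-allocator usage on entry and complains on exit if usage has grown, which is what surfaces the leak being reproduced here.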