@hccho2
Created December 4, 2020 10:39
import torch
import torch.nn as nn

lr = 0.1
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
lambda1 = lambda epoch: epoch / 10  # after k calls to scheduler.step(), lr = base lr * lambda1(k)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda1)

print(optimizer.state_dict())

for epoch in range(5):
    optimizer.step()
    scheduler.step()
    print(epoch, lr * (epoch + 1) / 10, optimizer.state_dict()['param_groups'][0]['lr'])  # optimizer.state_dict()['param_groups'] is a list of length 1
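
As a side note, instead of reading the learning rate back out of optimizer.state_dict(), the scheduler itself exposes it via get_last_lr(), which returns one value per parameter group (available in recent PyTorch versions; this is a minimal self-contained sketch, not part of the original gist):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: epoch / 10)

for epoch in range(5):
    optimizer.step()
    scheduler.step()
    # get_last_lr() returns a list with one entry per parameter group;
    # here there is a single group, so index 0 is the only entry.
    print(epoch, scheduler.get_last_lr()[0])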
hccho2 commented Dec 4, 2020
lr = 0.1
model1 = nn.Linear(10, 1)
model2 = nn.Linear(20, 1)

# one parameter group per model
params = [{'params': model1.parameters()}, {'params': model2.parameters()}]

optimizer = torch.optim.Adam(params, lr=lr)

# one lr_lambda per parameter group
lambda1 = lambda epoch: epoch / 10     # group 0: lr * epoch / 10
lambda2 = lambda epoch: 0.95 ** epoch  # group 1: lr * 0.95 ** epoch
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

print(optimizer.state_dict())

for epoch in range(5):
    optimizer.step()
    scheduler.step()
    # optimizer.state_dict()['param_groups'] is now a list of length 2, one entry per group
    print(epoch,
          lr * (epoch + 1) / 10, optimizer.state_dict()['param_groups'][0]['lr'],
          lr * 0.95 ** (epoch + 1), optimizer.state_dict()['param_groups'][1]['lr'])
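
For completeness, schedulers also have their own state_dict()/load_state_dict(), so the schedule position (last_epoch) survives checkpointing. A minimal sketch follows; 'ckpt.pth' is just a hypothetical path for illustration, and note that plain lambdas are not serialized into the state dict, so the lr_lambda functions must be recreated when the scheduler is rebuilt:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: epoch / 10)

for epoch in range(3):
    optimizer.step()
    scheduler.step()

# save both optimizer and scheduler state ('ckpt.pth' is a hypothetical path)
torch.save({'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()}, 'ckpt.pth')

ckpt = torch.load('ckpt.pth')
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])  # restores last_epoch, so the schedule resumes at step 3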
