# Source: GitHub gist daviesl/3a924141f3d485e776b48768044a6633
# (author @daviesl, last active July 2, 2024 00:56).
import torch
# Seed PyTorch's global RNG so repeated runs draw identical samples.
_SEED = 42
torch.manual_seed(_SEED)

# Category probabilities of the target distribution the cost `f` is measured against.
true_probs = torch.tensor((0.2, 0.5, 0.3))
def f(x):
    """Return the cost of the one-hot sample(s) ``x``.

    The cost is the negative log-probability of ``x`` under a single-draw
    multinomial distribution with probabilities ``true_probs``. Negating the
    log-likelihood turns "make samples likely under the target" into a
    quantity to be minimized.

    NOTE: averaged over x ~ p, this cost is the cross-entropy
    -sum_k p_k * log(true_probs[k]). That is linear in p, so its minimizer is
    a one-hot distribution on the largest entry of ``true_probs`` rather than
    ``true_probs`` itself.
    """
    target_dist = torch.distributions.Multinomial(total_count=1, probs=true_probs)
    log_likelihood = target_dist.log_prob(x)
    return -log_likelihood
n_batch = 10  # samples drawn per gradient estimate
n_iter = 2000  # number of optimizer steps

# Logits of the model distribution p_theta; all zeros means a uniform start.
theta_hat = torch.nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32))
optimizer = torch.optim.SGD([theta_hat], lr=0.01)
# Multiply the learning rate by 0.997 after every step.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)

# Minimize E_{x ~ p_theta}[f(x)] with the score-function (REINFORCE) estimator
#     grad E[f(x)] ~= (1/N) * sum_i f(x_i) * grad log p_theta(x_i).
# NOTE: this expected cost is the cross-entropy of true_probs under p_theta.
# Its minimizer is a one-hot distribution on the most probable target
# category, not true_probs itself.
for epoch in range(n_iter):
    optimizer.zero_grad()

    # Build the model distribution once and reuse it for sampling and scoring.
    model_dist = torch.distributions.Multinomial(1, logits=theta_hat)
    # Discrete samples are not differentiable, so x_sample is a constant here.
    x_sample = model_dist.sample((n_batch,))
    # One log-probability per sample; this is the only path for gradients to theta_hat.
    log_p_theta_x = model_dist.log_prob(x_sample)

    # Per-sample cost, detached so it acts purely as a weight on the score.
    f_hat = f(x_sample).detach()

    # Surrogate loss whose gradient is exactly the estimator above. Each cost
    # f(x_i) weights only its own score grad log p_theta(x_i). The previous
    # version summed the scores over the batch first, which paired every cost
    # with every sample's score and inflated the variance.
    surrogate = (f_hat * log_p_theta_x).mean()
    surrogate.backward()

    optimizer.step()
    scheduler.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Estimated Probs: {torch.softmax(theta_hat, dim=0).detach().numpy()}")

# Display the final estimated probabilities
estimated_final_probs = torch.softmax(theta_hat, dim=0)
print(f"Final Estimated Probabilities: {estimated_final_probs.detach().numpy()}")
# End of gist.