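For context (a note added here, not part of the original gist): the script is a minimal demonstration of the score-function (REINFORCE) gradient estimator. Because samples from a discrete distribution are not differentiable with respect to its parameters, the gradient of the expected objective is rewritten with the log-derivative trick and estimated by Monte Carlo:

\nabla_\theta \, \mathbb{E}_{x \sim p_\theta}[f(x)]
  = \mathbb{E}_{x \sim p_\theta}\!\left[ f(x) \, \nabla_\theta \log p_\theta(x) \right]
  \approx \frac{1}{N} \sum_{i=1}^{N} f(x_i) \, \nabla_\theta \log p_\theta(x_i),
  \quad x_i \sim p_\theta .

The loop below draws N = n_batch samples per step, forms this estimate, and hands it to SGD as the gradient of theta_hat.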
import torch

# Set the seed for reproducibility
torch.manual_seed(42)

# True target distribution probabilities
true_probs = torch.tensor([0.2, 0.5, 0.3])

def f(x):
    # Minus sign is important to ensure we are minimizing
    return -torch.distributions.Multinomial(total_count=1, probs=true_probs).log_prob(x)

n_batch = 10
n_iter = 2000

# Initialize the parameter estimates
theta_hat = torch.nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32))
optimizer = torch.optim.SGD([theta_hat], lr=0.01)
# Decay the learning rate geometrically: lr at step t is 0.01 * 0.997**t
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)
for epoch in range(n_iter):
    optimizer.zero_grad()
    # Sample from the estimated distribution
    dist = torch.distributions.Multinomial(1, logits=theta_hat)
    x_sample = dist.sample((n_batch,))
    log_p_theta_x = dist.log_prob(x_sample)
    # Calculate the target function
    f_hat = f(x_sample)
    # Score-function (REINFORCE) gradient estimate: passing `f_hat` as
    # `grad_outputs` multiplies each sample's score grad log p_theta(x_i)
    # by f(x_i) directly in a single backward pass
    weighted_score = torch.autograd.grad(outputs=log_p_theta_x, inputs=theta_hat,
                                         grad_outputs=f_hat.detach())[0]
    # Average over the batch to get the Monte Carlo gradient for the optimizer
    final_gradients = weighted_score / n_batch
    theta_hat.grad = final_gradients
    optimizer.step()
    scheduler.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Estimated Probs: {torch.softmax(theta_hat, dim=0).detach().numpy()}")

# Display the final estimated probabilities
estimated_final_probs = torch.softmax(theta_hat, dim=0)
print(f"Final Estimated Probabilities: {estimated_final_probs.detach().numpy()}")