Skip to content

Instantly share code, notes, and snippets.

@samedii
Created April 12, 2024 21:04
Show Gist options
  • Save samedii/49ff3f014dbbd80c685529957e285e24 to your computer and use it in GitHub Desktop.
Save samedii/49ff3f014dbbd80c685529957e285e24 to your computer and use it in GitHub Desktop.
# Preference-pair loss built from per-sample noise-prediction MSE, comparing
# the trained model's predictions against `reference_predictions`
# (DPO-style objective scaled by `beta_dpo` below).
# This is a fragment of a function body: the enclosing `def` is not part of
# the snippet, and the original indentation was lost in the paste.
# NOTE(review): because indentation is lost, it is not visible which lines
# belong inside `with torch.no_grad():`. The reference loss and its chunk
# presumably do. Whether `reference_weight` does changes gradient flow (see
# the note at that statement). Confirm against the original source.
#
# Assumes (TODO confirm with caller): the batch is laid out along dim 0 as
# [preferred ("good") samples; rejected ("bad") samples] in equal halves,
# and `predictions`, `reference_predictions` and `diffused_latent_images`
# share that layout.
#
# Per-sample MSE between `predictions.predicted_noise` and the target
# `diffused_latent_images.noise`: every dim after dim 0 is flattened and
# averaged, leaving one loss value per batch element.
instance_loss = (
F.mse_loss(
predictions.predicted_noise,
diffused_latent_images.noise,
reduction="none",
)
.flatten(start_dim=1)
.mean(dim=1)
)
# First half of the batch -> good_loss, second half -> bad_loss.
# With an odd batch size, `chunk` returns unequal halves.
good_loss, bad_loss = instance_loss.chunk(2, dim=0)
with torch.no_grad():
# The same per-sample MSE for `reference_predictions`. It is computed under
# `torch.no_grad()`, so no autograd graph is recorded for it.
reference_instance_loss = (
F.mse_loss(
reference_predictions.predicted_noise,
diffused_latent_images.noise,
reduction="none",
)
.flatten(start_dim=1)
.mean(dim=1)
)
reference_good_loss, reference_bad_loss = reference_instance_loss.chunk(
2, dim=0
)
# Scalar offset: the batch-wide mean of (reference loss - model loss),
# floored at 0. It is positive only when the model's mean loss over the whole
# batch (good and bad halves together) is already below the reference's.
# It is subtracted from both sigmoid arguments below.
# NOTE(review): if this statement runs outside `torch.no_grad()`, the offset
# is not detached, so gradients flow through it into `instance_loss`.
# Minimizing `loss` then pushes the offset down, which pushes the model's
# mean loss up whenever the clamp is inactive. If it is meant as a constant
# offset, consider `.detach()`; confirm intent.
reference_weight = (
(instance_loss.neg() - reference_instance_loss.neg())
.mean()
.clamp(min=0)
)
# Hard-coded scale applied to the loss differences inside both sigmoids.
beta_dpo = 5000
# Since `a.neg() - b.neg()` equals `b - a`, the two sigmoid arguments are:
#   good term: beta * (reference_good_loss - good_loss) - offset.
#     This is large when the model fits the good half better than the
#     reference does.
#   bad term:  beta * (bad_loss - reference_bad_loss) - offset.
#     This is large when the model fits the bad half worse than the
#     reference does.
# Both sigmoid means are negated, so `loss` lies in (-2, 0), and minimizing
# it drives both arguments upward.
# NOTE(review): a plain sigmoid (not log-sigmoid) saturates. Its gradient
# vanishes when an argument is strongly negative, and with beta_dpo = 5000
# even small loss differences reach that regime. Confirm this is intended.
loss = (
-F.sigmoid(
beta_dpo * (good_loss.neg() - reference_good_loss.neg())
- reference_weight
).mean()
- F.sigmoid(
-beta_dpo * (bad_loss.neg() - reference_bad_loss.neg())
- reference_weight
).mean()
)
return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment