Skip to content

Instantly share code, notes, and snippets.

Created February 2, 2024 15:13
Show Gist options
  • Save tiandiao123/71acbcaee7968a630dd7e175a3987071 to your computer and use it in GitHub Desktop.
Save tiandiao123/71acbcaee7968a630dd7e175a3987071 to your computer and use it in GitHub Desktop.
import torch
import torch.nn.functional as F
def _pct_rank(x, dim=0):
    """Return the 1-based ordinal rank of ``x`` along ``dim`` divided by its length.

    Values lie in ``[1/n, 1]`` (never zero, so safe to divide by), mirroring
    ``pandas.Series.rank(pct=True)``. Unlike pandas, tied values receive
    distinct ranks in an unspecified order instead of an averaged rank.
    """
    n = x.shape[dim]
    return (x.argsort(dim=dim).argsort(dim=dim).float() + 1.0) / n


def sample_reweight(loss_curve, loss_values, k_th, alpha1=1.0, alpha2=1.0, bins_sr=10, decay=0.9):
    """The SR (sample re-weighting) module of Double Ensemble, using PyTorch.

    Each sample gets an h-value ``alpha1 * h1 + alpha2 * h2``:

    - h1 is the percentile rank of the negated ensemble loss, so the sample
      with the highest ``loss_values`` gets the lowest h1.
    - h2 is the percentile rank of (mean loss rank over the last ~10% of
      iterations) / (mean loss rank over the first ~10%).

    The h-values are split into ``bins_sr`` equal-width bins spanning
    ``[min(h), max(h)]``. Every sample in a bin receives
    ``1 / (decay**k_th * mean_h_of_bin + 0.1)``, so for ``decay > 0`` bins
    with larger h-values get smaller weights.

    Parameters:
    - loss_curve: Tensor, shape (N, T), the loss curve for each sample over training iterations.
    - loss_values: Tensor, shape (N,), the loss of the current ensemble on each sample.
    - k_th: int, the index of the current sub-model, starting from 1.
    - alpha1: float, weight for h1 calculation.
    - alpha2: float, weight for h2 calculation.
    - bins_sr: int, number of bins for discretizing h-values.
    - decay: float, decay rate for adjusting weights.

    Returns:
    - weights: Tensor, shape (N,), float32, new weights for each sample, on the
      same device as ``loss_values``.

    Raises:
    - ValueError: if ``loss_curve`` is not (N, T) with T >= 1, if
      ``loss_values`` is not shaped (N,), or if ``bins_sr < 1``.
    """
    if loss_curve.dim() != 2 or loss_curve.shape[1] < 1:
        raise ValueError(f"loss_curve must have shape (N, T) with T >= 1, got {tuple(loss_curve.shape)}")
    N, T = loss_curve.shape
    if loss_values.shape != (N,):
        raise ValueError(f"loss_values must have shape ({N},), got {tuple(loss_values.shape)}")
    if bins_sr < 1:
        raise ValueError(f"bins_sr must be >= 1, got {bins_sr}")
    if N == 0:
        return torch.zeros(0, dtype=torch.float, device=loss_values.device)

    # Rank every iteration's losses across samples (dim=0, length N) as
    # percentiles. The 1-based ranks keep l_start strictly positive below.
    loss_curve_rank = _pct_rank(loss_curve, dim=0)
    # Negated, so the sample with the highest ensemble loss ranks lowest.
    h1 = _pct_rank(-loss_values, dim=0)

    # Mean rank over the first / last 10% of iterations (at least one column).
    part = max(int(T * 0.1), 1)
    l_start = loss_curve_rank[:, :part].mean(dim=1)
    l_end = loss_curve_rank[:, -part:].mean(dim=1)
    h2 = _pct_rank(l_end / l_start, dim=0)

    h_value = alpha1 * h1 + alpha2 * h2

    # Equal-width bin edges over [min(h), max(h)], as torch.histogram would
    # produce, but built with linspace so this also works off-CPU.
    h_min, h_max = h_value.min().item(), h_value.max().item()
    edges = torch.linspace(h_min, h_max, bins_sr + 1, dtype=h_value.dtype, device=h_value.device)
    # With right=True, a value equal to the last edge lands at index
    # bins_sr + 1. Clamping folds it into the last (right-closed) bin instead
    # of leaving it unbinned with weight 0. Shift to 0-based bin indices.
    h_bins = torch.bucketize(h_value, edges, right=True).clamp(1, bins_sr) - 1

    # Per-bin mean h-value, vectorized. Empty bins are never indexed, so
    # clamping their count to 1 only avoids a 0/0.
    bin_sum = torch.zeros(bins_sr, dtype=h_value.dtype, device=h_value.device).scatter_add_(0, h_bins, h_value)
    bin_count = torch.bincount(h_bins, minlength=bins_sr).clamp(min=1)
    bin_mean_h = bin_sum / bin_count

    return 1.0 / (decay ** k_th * bin_mean_h[h_bins] + 0.1)
# Example usage
if __name__ == "__main__":
    # Guarded so that importing this module does not run the demo.
    torch.manual_seed(0)  # reproducible demo output
    N, T = 100, 20
    # Synthetic standard-normal "losses"; sample_reweight ranks its inputs
    # before combining them, so only their ordering matters here.
    loss_curve = torch.randn(N, T)
    loss_values = torch.randn(N)
    k_th = 1
    weights = sample_reweight(loss_curve, loss_values, k_th)
    print(tuple(weights.shape), weights.min().item(), weights.max().item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment