@kristian-georgiev
Created October 26, 2023 02:28
LDS snippets
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler
from scipy.stats import spearmanr
from tqdm import tqdm


def get_margin(model, x_0s, tstep, iters, ddim=False, states=None):
    """Average noise-prediction loss at timestep `tstep`, averaged over `iters`
    independent noise draws, for each validation sample."""
    if states is not None:
        # MS COCO: condition the UNet on (randomly chosen) text-encoder states
        i = np.random.randint(len(states[0]))
        kwargs = {'encoder_hidden_states': states[:, i].clone(), 'return_dict': False}
    else:
        # CIFAR-10: unconditional UNet
        kwargs = {'return_dict': False}

    noise_scheduler = DDPMScheduler()
    margins = torch.zeros(x_0s.shape[0]).cuda()
    for step in range(iters):
        # Fixed seed per iteration so margins are comparable across models
        generator = torch.Generator(device='cuda')
        generator.manual_seed(step)

        t = torch.tensor(tstep)
        x_0 = x_0s[:, t].cuda()
        noise = torch.randn(x_0.shape, device=x_0.device, generator=generator)
        noisy_latent = noise_scheduler.add_noise(x_0, noise, t)

        with torch.no_grad():
            noise_pred = model(noisy_latent, t.cuda(), **kwargs)[0]
        # Per-sample MSE between predicted and true noise
        margins += F.mse_loss(noise_pred, noise, reduction='none').mean(dim=[1, 2, 3]).data

    return (margins / iters).cpu().numpy()
def eval_correlations(scores=None, masks=None, margins=None, preds=None):
    """Spearman correlation between datamodel predictions and observed margins,
    computed per validation sample (i.e., the LDS)."""
    val_inds = np.arange(100)  # hardcoded to the first 100 validation samples
    if preds is None:
        # Predicted output for each mask: sum of attribution scores over the
        # training samples included in that mask
        preds = masks @ scores

    rs = []
    ps = []
    for ind, j in tqdm(enumerate(val_inds)):
        r, p = spearmanr(preds[:, ind], margins[:, j])
        rs.append(r)
        ps.append(p)

    rs, ps = np.array(rs), np.array(ps)
    print(f'Correlation: {rs.mean():.3f} (avg p value {ps.mean():.6f})')
    return rs
def eval_lds(masks,         # shape=[num masks, num train samples]
             margins,       # shape=[num models per mask, num masks, num val samples]
             TRAK_scores):  # shape=[num train samples, num val samples]
    # Average margins over the models trained on each mask, then correlate
    # with the TRAK-predicted outputs
    LDS_hist = eval_correlations(scores=TRAK_scores, masks=masks, margins=margins.mean(axis=0))
    return LDS_hist
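

# --- Usage sketch (not part of the original snippets; a minimal illustration of
# how the pieces fit together). The variable names and dimensions below
# (`num_masks`, `trak_scores`, etc.) are placeholder assumptions, and random
# arrays stand in for real margins, which in practice would come from calling
# get_margin on models retrained on each masked training subset.
if __name__ == '__main__':
    num_masks, num_train, num_val, models_per_mask = 50, 1000, 100, 3

    # Binary subset masks: masks[m, i] == 1 iff train sample i is in subset m
    masks = (np.random.rand(num_masks, num_train) < 0.5).astype(np.float32)

    # Each entry would be a get_margin(...) output, stacked over the models
    # retrained per mask and over masks; random values keep the sketch runnable
    margins = np.random.randn(models_per_mask, num_masks, num_val)

    # Attribution scores for each (train sample, val sample) pair, e.g. from TRAK
    trak_scores = np.random.randn(num_train, num_val)

    lds_per_val_sample = eval_lds(masks, margins, trak_scores)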