Created
October 26, 2023 02:28
-
-
Save kristian-georgiev/a1bbea994d6e8c85e4d58037245401be to your computer and use it in GitHub Desktop.
LDS snippets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_margin(model, x_0s, tstep, iters, ddim=False, states=None):
    """Average the per-sample noise-prediction MSE at one timestep over several noise draws.

    In each of ``iters`` rounds, Gaussian noise is drawn from a CUDA generator
    seeded with the round index, added to the selected inputs by a default
    ``DDPMScheduler``, and the model's noise prediction is compared to the
    true noise with a per-sample mean-squared error. The per-round errors are
    summed and divided by ``iters``.

    Args:
        model: Callable invoked as ``model(noisy, t, **kwargs)``; element
            ``[0]`` of its result is used as the noise prediction.
        x_0s: Tensor of clean inputs, indexed as ``x_0s[:, tstep]``. The
            selected slice is presumably 4-D (batch, C, H, W), since the loss
            is averaged over dims 1-3 -- TODO confirm against the caller.
        tstep: Timestep; wrapped in ``torch.tensor`` and used both as the
            index into ``x_0s`` and as the scheduler/model timestep.
        iters: Number of noise draws to average over. Should be positive;
            with ``iters == 0`` the final division is 0/0.
        ddim: Unused; kept so existing callers keep working.
        states: Optional conditioning tensor. When given, one index along
            dim 1 is picked at random (via NumPy's global RNG, so it is not
            controlled by the per-round seeds) and ``states[:, i]`` is passed
            to the model as ``encoder_hidden_states``.

    Returns:
        NumPy array of shape ``(x_0s.shape[0],)`` with the averaged losses.
    """
    if states is not None:
        # MS COCO
        i = np.random.randint(len(states[0]))
        kwargs = {'encoder_hidden_states': states[:, i].clone(), 'return_dict': False}
    else:
        # CIFAR-10
        kwargs = {'return_dict': False}

    noise_scheduler = DDPMScheduler()

    # None of these depend on the round index, so build them once instead of
    # re-indexing and re-copying to the GPU on every iteration.
    t = torch.tensor(tstep)
    t_cuda = t.cuda()
    x_0 = x_0s[:, t].cuda()
    generator = torch.Generator(device='cuda')

    margins = torch.zeros(x_0s.shape[0]).cuda()
    for step in range(iters):
        # Re-seeding with the round index makes the noise of round `step`
        # reproducible across calls.
        generator.manual_seed(step)
        noise = torch.randn(x_0.shape, device=x_0.device, generator=generator)
        noisy_latent = noise_scheduler.add_noise(x_0, noise, t)
        with torch.no_grad():
            noise_pred = model(noisy_latent, t_cuda, **kwargs)[0]
            # Per-sample loss: average the element-wise error over dims 1-3.
            margins += F.mse_loss(noise_pred, noise, reduction='none').mean(dim=[1, 2, 3])
    return (margins / iters).cpu().numpy()
def eval_correlations(scores=None, masks=None, margins=None, preds=None, num_val=100):
    """Compute per-validation-sample Spearman correlations between predictions and margins.

    For each of the first ``num_val`` validation samples ``j``, the Spearman
    rank correlation between column ``j`` of ``preds`` and column ``j`` of
    ``margins`` is computed. The mean correlation and mean p-value are
    printed.

    Args:
        scores: Attribution scores; used only when ``preds`` is None, in
            which case ``preds = masks @ scores``.
        masks: Subset masks; used only when ``preds`` is None.
        margins: 2-D array of measured outputs, indexed as ``margins[:, j]``.
        preds: Optional precomputed predictions, indexed as ``preds[:, j]``.
        num_val: Number of leading validation samples (columns) to evaluate.
            Defaults to 100, the previously hard-coded value.

    Returns:
        NumPy array of length ``num_val`` with the Spearman correlation for
        each evaluated validation sample.
    """
    if preds is None:
        preds = masks @ scores

    rs = []
    ps = []
    # Wrap the sized range rather than an enumerate() iterator, so tqdm
    # knows the total and can display progress and an ETA.
    for j in tqdm(range(num_val)):
        r, p = spearmanr(preds[:, j], margins[:, j])
        rs.append(r)
        ps.append(p)
    rs, ps = np.array(rs), np.array(ps)
    print(f'Correlation: {rs.mean():.3f} (avg p value {ps.mean():.6f})')
    return rs
def eval_lds(masks,         # shape=[num masks, num train samples]
             margins,       # shape=[num models per mask, num masks, num val samples]
             TRAK_scores):  # shape=[num train samples, num val samples]
    """Evaluate the linear datamodeling score (LDS) of a set of attribution scores.

    ``margins`` is averaged over its first axis (the models trained per mask)
    and the result is passed, together with ``masks`` and ``TRAK_scores``, to
    ``eval_correlations``.

    Returns:
        The array of per-validation-sample Spearman correlations produced by
        ``eval_correlations``. Previously the result was computed and then
        discarded, so the function returned None.
    """
    return eval_correlations(scores=TRAK_scores, masks=masks, margins=margins.mean(axis=0))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment