Skip to content

Instantly share code, notes, and snippets.

View kristian-georgiev's full-sized avatar
🤸

Kristian Georgiev kristian-georgiev

🤸
View GitHub Profile
@kristian-georgiev
kristian-georgiev / lds.py
Created October 26, 2023 02:28
LDS snippets
def get_margin(model, x_0s, tstep, iters, ddim=False, states=None):
if states is not None:
# MS COCO
i = np.random.randint(len(states[0]))
kwargs = {'encoder_hidden_states': states[:, i].clone(), 'return_dict': False}
else:
# CIFAR-10
kwargs = {'return_dict': False}
import torch
import numpy as np
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from robustness.tools.custom_modules import SequentialWithArgs, FakeReLU
from e2cnn import gspaces
from e2cnn import nn as enn