A simple implementation of [Protein Discovery with Discrete Walk-Jump Sampling](http://arxiv.org/abs/2306.12360)
# Credit : gpt-4o
# Reference : [Protein Discovery with Discrete Walk-Jump Sampling](http://arxiv.org/abs/2306.12360)
import torch
import torch.nn as nn
import torch.optim as optim
import math
ebm_energy_regularization_scale = 0.1 # for L2 regularization on EBM loss
def log_sum_exp(x):
    # Numerically stable log-sum-exp over all elements of x
    max_val = x.max()
    return max_val + torch.log(torch.sum(torch.exp(x - max_val)))
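# Quick sanity check (illustrative only, not part of the original gist): on a flattened
# tensor this helper should agree with torch.logsumexp. The underscore names are throwaway.
_x = torch.randn(8, 1)
assert torch.allclose(log_sum_exp(_x), torch.logsumexp(_x.flatten(), dim=0))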
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape (1, max_len, d_model) to match batch-first inputs
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add encodings along the sequence dimension
        return x + self.pe[:, :x.size(1), :]
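# Illustrative shape check for the batch-first positional encoding above
# (the sizes and underscore names are arbitrary, only for this check).
_pos = PositionalEncoding(d_model=16)
assert _pos(torch.zeros(4, 10, 16)).shape == (4, 10, 16)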
class TransformerDenoiser(nn.Module):
    def __init__(self, input_dim, model_dim, num_layers, num_heads):
        super(TransformerDenoiser, self).__init__()
        self.embedding = nn.Linear(input_dim, model_dim)
        self.pos_encoder = PositionalEncoding(model_dim)
        encoder_layers = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dim_feedforward=model_dim, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.denoise_head = nn.Linear(model_dim, input_dim)

    def forward(self, src):
        src = src.unsqueeze(1)  # Add sequence length dimension: (batch, 1, input_dim)
        src = self.embedding(src) * math.sqrt(self.embedding.out_features)  # Scale by sqrt(model_dim)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src)
        output = self.denoise_head(output)
        output = output.squeeze(1)  # Remove sequence length dimension
        return output
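# Illustrative smoke test: the denoiser maps (batch, input_dim) back to (batch, input_dim).
# The sizes here are arbitrary and only used for this check.
_denoiser_check = TransformerDenoiser(input_dim=32, model_dim=32, num_layers=1, num_heads=4)
assert _denoiser_check(torch.randn(4, 32)).shape == (4, 32)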
class EnergyBasedModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(EnergyBasedModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.net(x)
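# Illustrative check: the EBM assigns one scalar energy per sample, i.e. shape (batch, 1).
# Sizes are arbitrary; BatchNorm requires a batch size greater than one here.
_ebm_check = EnergyBasedModel(input_dim=32, hidden_dim=16)
assert _ebm_check(torch.randn(4, 32)).shape == (4, 1)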
def langevin_mcmc_step(y, model, step_size):
    # One unadjusted Langevin step: y_next = y - step_size * grad E(y) + sqrt(2 * step_size) * noise
    y = y.detach().requires_grad_(True)
    energy = model(y).sum()
    # Differentiate w.r.t. y only, so the model parameters do not accumulate gradients here
    grad = torch.autograd.grad(energy, y)[0]
    y_next = y - step_size * grad + math.sqrt(2 * step_size) * torch.randn_like(y)
    return y_next.detach()
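# Illustrative use on a toy quadratic energy E(y) = ||y||^2 / 2, whose minimum is at the origin
# (hypothetical example, not part of the walk-jump pipeline): a few Langevin steps should pull
# random points towards zero.
_toy_energy = lambda y: 0.5 * (y ** 2).sum(dim=1, keepdim=True)
_y_toy = torch.randn(4, 8) * 3.0
for _ in range(50):
    _y_toy = langevin_mcmc_step(_y_toy, _toy_energy, step_size=0.05)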
def walk_jump_sampling(init_y, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule):
    y = init_y.clone()
    # Walk: Langevin MCMC on the (smoothed) EBM energy landscape
    for _ in range(num_walk_steps):
        y = langevin_mcmc_step(y, ebm, walk_step_size)
    # Jump: denoise towards clean samples over the noise schedule
    for t in range(num_jump_steps):
        sigma_t = noise_schedule[t]
        noise = torch.randn_like(y) * sigma_t
        noisy_y = y + noise
        denoised_y = denoiser(noisy_y)
        y = denoised_y  # Jump step: the MSE-trained denoiser directly outputs the clean estimate
    return y
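# Minimal shape/flow check of walk-jump sampling with small untrained stand-in models
# (sizes are arbitrary; with untrained networks the output is meaningless, this only
# verifies that the walk and jump stages compose).
_wj_ebm = EnergyBasedModel(input_dim=32, hidden_dim=16)
_wj_denoiser = TransformerDenoiser(input_dim=32, model_dim=32, num_layers=1, num_heads=4)
_wj_schedule = torch.linspace(1.0, 0.1, 3)
_wj_samples = walk_jump_sampling(torch.randn(4, 32), _wj_ebm, _wj_denoiser, num_walk_steps=2, walk_step_size=0.01, num_jump_steps=3, noise_schedule=_wj_schedule)
assert _wj_samples.shape == (4, 32)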
def train_walk_jump(ebm, denoiser, data_loader, num_epochs, walk_step_size, num_walk_steps, num_jump_steps, sigma_max, sigma_min):
    optimizer_ebm = optim.AdamW(ebm.parameters(), lr=1e-4, weight_decay=1e-5)  # Added weight decay
    optimizer_denoiser = optim.AdamW(denoiser.parameters(), lr=1e-4, weight_decay=1e-5)  # Added weight decay
    scheduler_ebm = optim.lr_scheduler.StepLR(optimizer_ebm, step_size=5, gamma=0.5)
    scheduler_denoiser = optim.lr_scheduler.StepLR(optimizer_denoiser, step_size=5, gamma=0.5)
    noise_schedule = torch.linspace(sigma_max, sigma_min, num_jump_steps).to(device)

    for epoch in range(num_epochs):
        for clean_data in data_loader:
            clean_data = clean_data.to(device)
            noisy_data = clean_data + torch.randn_like(clean_data) * sigma_max

            # Train EBM
            optimizer_ebm.zero_grad()
            energy_real = ebm(noisy_data)

            # Generate samples using walk-jump
            init_y = torch.randn_like(noisy_data).to(device)
            generated_samples = walk_jump_sampling(init_y, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule)
            energy_fake = ebm(generated_samples.detach())  # Detach so the EBM loss does not backprop into the denoiser

            # Compute EBM loss with contrastive divergence
            #loss_ebm = (energy_real.mean() - energy_fake.mean()) + ebm_energy_regularization_scale * (energy_real ** 2).mean()  # Added L2 regularization

            # Compute EBM loss with contrastive divergence, log-sum-exp trick and offset
            loss_ebm = log_sum_exp(energy_real) - log_sum_exp(energy_fake) + ebm_energy_regularization_scale * (energy_real ** 2).mean()  # Added L2 regularization
            loss_ebm.backward()
            nn.utils.clip_grad_norm_(ebm.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer_ebm.step()

            # Train denoiser: accumulate the denoising loss over all noise levels, then step once
            optimizer_denoiser.zero_grad()
            for t in range(num_jump_steps):
                sigma_t = noise_schedule[t]
                noise = torch.randn_like(clean_data) * sigma_t
                noisy_y = clean_data + noise
                denoised_y = denoiser(noisy_y)
                if epoch == 0 and t == 0:  # Print shapes only once
                    print(f"noisy_y.shape = {noisy_y.shape} , denoised_y.shape = {denoised_y.shape} , clean_data.shape = {clean_data.shape}")
                loss_denoiser = nn.MSELoss()(denoised_y, clean_data)
                loss_denoiser.backward()
            nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer_denoiser.step()

        scheduler_ebm.step()
        scheduler_denoiser.step()
        print(f"Epoch {epoch + 1}/{num_epochs}, EBM Loss: {loss_ebm.item()}, Denoiser Loss: {loss_denoiser.item()}")
# Select device: CUDA if available, then Apple MPS, otherwise fall back to CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# Define parameters
input_dim = 512
model_dim = 512
hidden_dim = 256
num_layers = 6
num_heads = 8
num_walk_steps = 10
num_jump_steps = 5
walk_step_size = 0.01
sigma_max = 1.0
sigma_min = 0.1
num_epochs = 10
# Initialize models
ebm = EnergyBasedModel(input_dim, hidden_dim).to(device)
denoiser = TransformerDenoiser(input_dim, model_dim, num_layers, num_heads).to(device)
# Dummy data loader
data_loader = torch.utils.data.DataLoader(torch.randn(100, input_dim), batch_size=32, shuffle=True)
# Train models
train_walk_jump(ebm, denoiser, data_loader, num_epochs, walk_step_size, num_walk_steps, num_jump_steps, sigma_max, sigma_min)
# Define noise schedule
noise_schedule = torch.linspace(sigma_max, sigma_min, num_jump_steps).to(device)
# Sample using walk-jump (switch to eval mode so BatchNorm uses running statistics)
ebm.eval()
denoiser.eval()
init_y = torch.randn(32, input_dim).to(device)
samples = walk_jump_sampling(init_y, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule)
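# Illustrative post-hoc check (not from the paper): compare the learned energy of the
# walk-jump samples against pure Gaussian noise; after training, the samples should
# tend towards lower energy than noise.
with torch.no_grad():
    print(f"samples.shape = {samples.shape}")
    print(f"mean energy of walk-jump samples: {ebm(samples).mean().item():.4f}")
    print(f"mean energy of random noise:      {ebm(torch.randn_like(samples)).mean().item():.4f}")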