A simple implementation of [Protein Discovery with Discrete Walk-Jump Sampling](http://arxiv.org/abs/2306.12360)
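The script below trains two models on dummy data, a small MLP energy-based model (EBM) and a transformer denoiser, and then draws samples by alternating Langevin MCMC "walk" steps under the EBM with denoising "jump" steps, loosely following the walk-jump scheme of the paper.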
# Credit : gpt-4o
# Reference : [Protein Discovery with Discrete Walk-Jump Sampling](http://arxiv.org/abs/2306.12360)
import torch
import torch.nn as nn
import torch.optim as optim
import math

ebm_energy_regularization_scale = 0.1  # for L2 regularization on EBM loss
# Numerically stable log-sum-exp: shift by the max before exponentiating
def log_sum_exp(x):
    max_val = x.max()
    return max_val + torch.log(torch.sum(torch.exp(x - max_val)))
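
# Standard sinusoidal positional encoding (Vaswani et al., 2017):
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# Stored batch-first as (1, max_len, d_model) so it broadcasts over the batch.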
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), batch-first
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is batch-first: (batch, seq_len, d_model); index by sequence position
        return x + self.pe[:, :x.size(1), :]
class TransformerDenoiser(nn.Module):
    def __init__(self, input_dim, model_dim, num_layers, num_heads):
        super(TransformerDenoiser, self).__init__()
        self.model_dim = model_dim
        self.embedding = nn.Linear(input_dim, model_dim)
        self.pos_encoder = PositionalEncoding(model_dim)
        encoder_layers = nn.TransformerEncoderLayer(model_dim, num_heads, model_dim, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.denoise_head = nn.Linear(model_dim, input_dim)

    def forward(self, src):
        src = src.unsqueeze(1)  # Add sequence length dimension: (batch, 1, input_dim)
        src = self.embedding(src) * math.sqrt(self.model_dim)  # Scale by sqrt(d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src)
        output = self.denoise_head(output)
        output = output.squeeze(1)  # Remove sequence length dimension
        return output
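
# A small MLP mapping a sample to a scalar energy. Lower energy should mean
# "more data-like"; the training loop below pushes energy down on noisy real
# data and up on the model's own walk-jump samples.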
class EnergyBasedModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(EnergyBasedModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.net(x)
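
# One step of unadjusted Langevin dynamics on the EBM energy E:
#   y_{k+1} = y_k - delta * grad_y E(y_k) + sqrt(2 * delta) * eps,  eps ~ N(0, I)
# This is the "walk" phase: it explores the smoothed density exp(-E(y))
# without leaving the noisy manifold.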
def langevin_mcmc_step(y, model, step_size):
    y = y.detach().requires_grad_(True)
    energy = model(y).sum()
    # Differentiate w.r.t. y only, so no gradients accumulate in the model parameters
    grad = torch.autograd.grad(energy, y)[0]
    y_next = y - step_size * grad + math.sqrt(2 * step_size) * torch.randn_like(y)
    return y_next.detach()
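
# Walk-jump sampling (cf. the referenced paper): first "walk" with Langevin
# MCMC steps under the EBM, then "jump" back to a clean estimate with the
# denoiser. This sketch repeats the jump over a decreasing noise schedule;
# in the paper the jump is a single denoising step.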
def walk_jump_sampling(init_y, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule):
    y = init_y.clone()
    for _ in range(num_walk_steps):
        y = langevin_mcmc_step(y, ebm, walk_step_size)
    for t in range(num_jump_steps):
        sigma_t = noise_schedule[t]
        noise = torch.randn_like(y) * sigma_t
        noisy_y = y + noise
        # The denoiser is trained to predict the clean sample directly (MSE),
        # so its output is already the jump estimate of x given noisy_y
        y = denoiser(noisy_y)
    return y.detach()
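
# Training loop: the EBM is trained contrastively (real noisy data vs. its own
# walk-jump samples, using a log-sum-exp over batch energies plus an L2 energy
# penalty), while the denoiser is trained with MSE to recover clean data from
# each noise level in the schedule.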
def train_walk_jump(ebm, denoiser, data_loader, num_epochs, walk_step_size, num_walk_steps, num_jump_steps, sigma_max, sigma_min):
    optimizer_ebm = optim.AdamW(ebm.parameters(), lr=1e-4, weight_decay=1e-5)  # Added weight decay
    optimizer_denoiser = optim.AdamW(denoiser.parameters(), lr=1e-4, weight_decay=1e-5)  # Added weight decay
    scheduler_ebm = optim.lr_scheduler.StepLR(optimizer_ebm, step_size=5, gamma=0.5)
    scheduler_denoiser = optim.lr_scheduler.StepLR(optimizer_denoiser, step_size=5, gamma=0.5)
    noise_schedule = torch.linspace(sigma_max, sigma_min, num_jump_steps)

    for epoch in range(num_epochs):
        for clean_data in data_loader:
            clean_data = clean_data.to(device)
            noisy_data = clean_data + torch.randn_like(clean_data) * sigma_max

            # Train EBM
            optimizer_ebm.zero_grad()
            energy_real = ebm(noisy_data)
            # Generate samples using walk-jump (returned tensor is detached)
            init_y = torch.randn_like(noisy_data).to(device)
            generated_samples = walk_jump_sampling(init_y, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule)
            energy_fake = ebm(generated_samples)
            # Compute EBM loss with contrastive divergence
            #loss_ebm = (energy_real.mean() - energy_fake.mean()) + ebm_energy_regularization_scale * (energy_real ** 2).mean()  # Added L2 regularization
            # Compute EBM loss with contrastive divergence, log-sum-exp trick and L2 regularization
            loss_ebm = log_sum_exp(energy_real) - log_sum_exp(energy_fake) + ebm_energy_regularization_scale * (energy_real ** 2).mean()
            loss_ebm.backward()
            nn.utils.clip_grad_norm_(ebm.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer_ebm.step()

            # Train denoiser: one MSE step per noise level in the schedule
            for t in range(num_jump_steps):
                sigma_t = noise_schedule[t]
                noise = torch.randn_like(clean_data) * sigma_t
                noisy_y = clean_data + noise
                optimizer_denoiser.zero_grad()
                denoised_y = denoiser(noisy_y)
                if epoch == 0 and t == 0:  # Print shapes only once
                    print(f"noisy_y.shape = {noisy_y.shape} , denoised_y.shape = {denoised_y.shape} , clean_data.shape = {clean_data.shape}")
                loss_denoiser = nn.MSELoss()(denoised_y, clean_data)
                loss_denoiser.backward()
                nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)  # Gradient clipping
                optimizer_denoiser.step()

        scheduler_ebm.step()
        scheduler_denoiser.step()
        print(f"Epoch {epoch + 1}/{num_epochs}, EBM Loss: {loss_ebm.item():.4f}, Denoiser Loss: {loss_denoiser.item():.4f}")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# Define parameters
input_dim = 512
model_dim = 512
hidden_dim = 256
num_layers = 6
num_heads = 8
num_walk_steps = 10
num_jump_steps = 5
walk_step_size = 0.01
sigma_max = 1.0
sigma_min = 0.1
num_epochs = 10

# Initialize models
ebm = EnergyBasedModel(input_dim, hidden_dim).to(device)
denoiser = TransformerDenoiser(input_dim, model_dim, num_layers, num_heads).to(device)

# Dummy data loader
data_loader = torch.utils.data.DataLoader(torch.randn(100, input_dim), batch_size=32, shuffle=True)

# Train models
train_walk_jump(ebm, denoiser, data_loader, num_epochs, walk_step_size, num_walk_steps, num_jump_steps, sigma_max, sigma_min)

# Define noise schedule
noise_schedule = torch.linspace(sigma_max, sigma_min, num_jump_steps).to(device)

# Sample using walk-jump
init_y = torch.randn(32, input_dim).to(device)
samples = walk_jump_sampling(init_y, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule)
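
# Quick sanity check (a minimal sketch, not from the paper): compare the EBM
# energies of the initial noise and the final samples; the Langevin walk
# should typically have driven the energy down.
ebm.eval()
with torch.no_grad():
    print(f"samples.shape = {samples.shape}")
    print(f"mean energy of initial noise    : {ebm(init_y).mean().item():.4f}")
    print(f"mean energy of walk-jump samples: {ebm(samples).mean().item():.4f}")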