# Reused from https://github.com/pinellolab/DNA-Diffusion/blob/codebase/src/data/sequence_dataloader.py
import pandas as pd
import numpy as np
import torch
import torchvision.transforms as T
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
class SequenceDatasetBase(Dataset):
    def __init__(self, data_path, sequence_length=200, sequence_encoding="polar",
                 sequence_transform=None, cell_type_transform=None):
        super().__init__()
        self.data = pd.read_csv(data_path, sep="\t")
        self.sequence_length = sequence_length
        self.sequence_encoding = sequence_encoding
        self.sequence_transform = sequence_transform
        self.cell_type_transform = cell_type_transform
        self.alphabet = ["A", "C", "T", "G"]
        self.check_data_validity()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Encode the DNA sequence at the current index
        current_seq = self.data["raw_sequence"][index]
        if "N" in current_seq:
            # check_data_validity() already rejects characters outside the alphabet,
            # so this should never trigger; raising keeps X_seq from being undefined.
            raise ValueError(f"Sequence at index {index} contains 'N' and cannot be encoded.")
        X_seq = self.encode_sequence(current_seq, encoding=self.sequence_encoding)

        # Read the cell component at the current index
        X_cell_type = self.data["component"][index]

        if self.sequence_transform is not None:
            X_seq = self.sequence_transform(X_seq)
        if self.cell_type_transform is not None:
            X_cell_type = self.cell_type_transform(X_cell_type)

        return X_seq, X_cell_type

    def check_data_validity(self):
        """
        Checks if the data is valid.
        """
        if not set("".join(self.data["raw_sequence"])).issubset(set(self.alphabet)):
            raise ValueError("Sequence contains invalid characters.")

        uniq_raw_seq_len = self.data["raw_sequence"].str.len().unique()
        if len(uniq_raw_seq_len) != 1 or uniq_raw_seq_len[0] != self.sequence_length:
            raise ValueError("The sequence length does not match the data.")

    def encode_sequence(self, seq, encoding):
        """
        Encodes a sequence using the given encoding scheme ("polar", "onehot", "ordinal").
        """
        if encoding == "polar":
            seq = self.one_hot_encode(seq).T
            seq[seq == 0] = -1
        elif encoding == "onehot":
            seq = self.one_hot_encode(seq).T
        elif encoding == "ordinal":
            seq = np.array([self.alphabet.index(n) for n in seq])
        else:
            raise ValueError(f"Unknown encoding scheme: {encoding}")
        return seq

    def one_hot_encode(self, seq):
        """
        One-hot encodes a single sequence into a (sequence_length, len(alphabet)) array.
        """
        seq_len = len(seq)
        seq_array = np.zeros((self.sequence_length, len(self.alphabet)))
        for i in range(seq_len):
            seq_array[i, self.alphabet.index(seq[i])] = 1
        return seq_array
class SequenceDatasetTrain(SequenceDatasetBase):
    def __init__(self, data_path="", **kwargs):
        super().__init__(data_path=data_path, **kwargs)


class SequenceDatasetValidation(SequenceDatasetBase):
    def __init__(self, data_path="", **kwargs):
        super().__init__(data_path=data_path, **kwargs)


class SequenceDatasetTest(SequenceDatasetBase):
    def __init__(self, data_path="", **kwargs):
        super().__init__(data_path=data_path, **kwargs)
class SequenceDataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_path=None,
        val_path=None,
        test_path=None,
        sequence_length=200,
        sequence_encoding="polar",
        sequence_transform=None,
        cell_type_transform=None,
        batch_size=None,
        num_workers=1,
    ):
        super().__init__()
        self.datasets = dict()
        # Only expose a dataloader method when the corresponding path is given
        self.train_dataloader, self.val_dataloader, self.test_dataloader = None, None, None
        if train_path:
            self.datasets["train"] = train_path
            self.train_dataloader = self._train_dataloader
        if val_path:
            self.datasets["validation"] = val_path
            self.val_dataloader = self._val_dataloader
        if test_path:
            self.datasets["test"] = test_path
            self.test_dataloader = self._test_dataloader

        self.sequence_length = sequence_length
        self.sequence_encoding = sequence_encoding
        self.sequence_transform = sequence_transform
        self.cell_type_transform = cell_type_transform
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):  # Lightning passes a `stage` argument ("fit", "test", ...)
        if "train" in self.datasets:
            self.train_data = SequenceDatasetTrain(
                data_path=self.datasets["train"],
                sequence_length=self.sequence_length,
                sequence_encoding=self.sequence_encoding,
                sequence_transform=self.sequence_transform,
                cell_type_transform=self.cell_type_transform,
            )
        if "validation" in self.datasets:
            self.val_data = SequenceDatasetValidation(
                data_path=self.datasets["validation"],
                sequence_length=self.sequence_length,
                sequence_encoding=self.sequence_encoding,
                sequence_transform=self.sequence_transform,
                cell_type_transform=self.cell_type_transform,
            )
        if "test" in self.datasets:
            self.test_data = SequenceDatasetTest(
                data_path=self.datasets["test"],
                sequence_length=self.sequence_length,
                sequence_encoding=self.sequence_encoding,
                sequence_transform=self.sequence_transform,
                cell_type_transform=self.cell_type_transform,
            )

    def _train_dataloader(self):
        return DataLoader(self.train_data,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def _val_dataloader(self):
        return DataLoader(self.val_data,
                          batch_size=self.batch_size,
                          shuffle=False,  # no need to shuffle validation data
                          num_workers=self.num_workers,
                          pin_memory=True)

    def _test_dataloader(self):
        return DataLoader(self.test_data,
                          batch_size=self.batch_size,
                          shuffle=False,  # no need to shuffle test data
                          num_workers=self.num_workers,
                          pin_memory=True)
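

# A minimal usage sketch of the data module above (added for illustration; not part of the
# original gist). The file paths are placeholders and are assumed to point at tab-separated
# files with "raw_sequence" and "component" columns, as SequenceDatasetBase expects.
if __name__ == "__main__":
    data_module = SequenceDataModule(
        train_path="train.tsv",      # placeholder path
        val_path="validation.tsv",   # placeholder path
        sequence_length=200,
        sequence_encoding="polar",
        batch_size=128,
        num_workers=1,
    )
    data_module.setup()
    train_loader = data_module.train_dataloader()
    x_seq, x_cell_type = next(iter(train_loader))
    print(x_seq.shape)  # expected shape: (batch_size, 4, 200) for the "polar" encoding

# (End of sequence_dataloader.py; the GAN script below imports it as a module.)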
import imageio
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from torch.autograd import Variable
from torch.autograd import grad as torch_grad
import sequence_dataloader
'''
https://www.biorxiv.org/content/10.1101/2022.07.26.501466v1.full.pdf#page=10
Generative Adversarial Network
To train a GAN model, we used Wasserstein GAN architecture with gradient penalty similar to earlier work.
The model consists of two parts; generator and discriminator. Generator takes noise as input (size is 128),
followed by a dense layer with 64,000 (500 * 128) units with ELU activation, a reshape layer (500, 128),
a convolution tower of 5 convolution blocks with skip connections,
a 1D convolution layer with 4 filters with kernel width 1, and finally a SOFTMAX activation layer.
The output of the generator is a 500 × 4 matrix, which represents one-hot encoded DNA sequence.
Discriminator takes 500 bp one-hot encoded DNA sequence as input (real or fake),
followed by a 1D convolution layer with 128 filters with kernel width 1,
a convolution tower of 5 convolution blocks with skip connections, a flatten layer,
and finally a dense layer with 1 unit.
Each block in the convolution tower consists of a RELU activation layer
followed by 1D convolution with 128 filters with kernel width 5.
The noise is generated by the numpy.random.normal(0, 1, (batch_size, 128)) command. We used a batch size of 128.
For every train_on_batch iteration of the generator, we performed 10 train_on_batch iteration for the discriminator.
We used Adam optimizer with learning_rate of 0.0001, beta_1 of 0.5, and beta_2 of 0.9.
We trained the models for around 260,000 batch training iteration for KC and
around 160,000 batch training iteration for MEL.
'''
BATCH_SIZE = 128  # the quoted paper trains with a batch size of 128 (260,000 is its iteration count for KC, not a batch size)
SIZE_OF_INPUT = 128  # dimension of the latent noise vector
SIZE_OF_FEATURE_MAP = 128
DNA_BP = 500  # length of the generated DNA sequences in base pairs
SIZE_OF_HIDDEN_LAYERS = DNA_BP * SIZE_OF_INPUT  # 64,000 (500 * 128) dense units
NUM_OF_1D_CONV_FILTERS = 4  # one output channel per nucleotide
NUM_OF_CONV_BLOCKS = 5  # depth of the convolution tower
NUM_EPOCHS = 6000
LEARNING_RATE = 1e-4  # Adam learning rate of 0.0001 per the quoted paper
ADAM_BETAS = (0.5, 0.9)  # beta_1 = 0.5, beta_2 = 0.9 per the quoted paper
USE_CUDA = torch.cuda.is_available()
'''
Commands to obtain Meuleman's small DNA dataset:
wget https://www.dropbox.com/s/db6up7c0d4jwdp4/train_all_classifier_WM20220916.csv.gz?dl=2
mv train_all_classifier_WM20220916.csv.gz?dl=2 train_all_classifier_WM20220916.csv.gz
gunzip -d ./train_all_classifier_WM20220916.csv.gz
'''
encode_data = sequence_dataloader.SequenceDatasetBase(data_path="./train_all_classifier_WM20220916.csv",
                                                      sequence_length=200, sequence_encoding="polar",
                                                      sequence_transform=None, cell_type_transform=None)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Dense layer expanding the latent noise to 500 * 128 units
        self.linears = nn.Sequential(
            nn.Linear(SIZE_OF_INPUT, SIZE_OF_HIDDEN_LAYERS),
            nn.ReLU(),  # the quoted paper uses ELU here; replaced with ReLU
            nn.Dropout()
        )
        # Convolution tower of 5 blocks with skip connections; each block is a ReLU
        # followed by a 1D convolution with 128 filters of kernel width 5 (per the paper)
        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                nn.ReLU(),
                nn.Conv1d(SIZE_OF_FEATURE_MAP, SIZE_OF_FEATURE_MAP, kernel_size=5, padding=2),
            )
            for _ in range(NUM_OF_CONV_BLOCKS)
        ])
        # Final 1D convolution with 4 filters of kernel width 1: one channel per nucleotide
        self.conv_1d = nn.Conv1d(SIZE_OF_FEATURE_MAP, NUM_OF_1D_CONV_FILTERS, kernel_size=1)
        # SOFTMAX activation over the nucleotide channels
        self.softmax = nn.Softmax(dim=1)

    def sample_latent(self, num_samples):
        # The paper draws the noise from numpy.random.normal(0, 1, (batch_size, 128))
        return torch.randn(num_samples, SIZE_OF_INPUT)

    def forward(self, x):
        # x: (batch_size, SIZE_OF_INPUT) latent noise
        dense_output = self.linears(x)
        # Reshape to (batch_size, 128 channels, 500 positions) for the 1D convolutions
        out = dense_output.view(-1, SIZE_OF_FEATURE_MAP, DNA_BP)
        # Convolution tower with skip (residual) connections
        for block in self.conv_blocks:
            out = out + block(out)
        # The output is a 4 x 500 matrix per sample, representing a one-hot-like DNA sequence
        return self.softmax(self.conv_1d(out))
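

# The gist defines no discriminator even though the Trainer below expects one, so here is a
# minimal sketch of a WGAN-GP critic following the architecture quoted from the paper above
# (a 1D conv with 128 filters of width 1, a tower of 5 ReLU + Conv1d(128, width 5) blocks
# with skip connections, a flatten layer, and a dense layer with 1 unit). This is an
# assumption added for illustration, not the original author's implementation.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Input: (batch_size, 4, DNA_BP) one-hot encoded DNA sequence (real or fake)
        self.conv_in = nn.Conv1d(NUM_OF_1D_CONV_FILTERS, SIZE_OF_FEATURE_MAP, kernel_size=1)
        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                nn.ReLU(),
                nn.Conv1d(SIZE_OF_FEATURE_MAP, SIZE_OF_FEATURE_MAP, kernel_size=5, padding=2),
            )
            for _ in range(NUM_OF_CONV_BLOCKS)
        ])
        # Dense layer with 1 unit producing the critic score (no sigmoid in a WGAN critic)
        self.fc_out = nn.Linear(SIZE_OF_FEATURE_MAP * DNA_BP, 1)

    def forward(self, x):
        out = self.conv_in(x)
        for block in self.conv_blocks:
            out = out + block(out)  # skip connection around each block
        out = out.flatten(start_dim=1)
        return self.fc_out(out)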
# See https://zhuanlan.zhihu.com/p/25071913 for a chinese explanation on WGAN-GP
# Reused from https://github.com/EmilienDupont/wgan-gp/blob/ef82364f2a2ec452a52fbf4a739f95039ae76fe3/training.py
class Trainer:
    def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer,
                 gp_weight=10, critic_iterations=5, print_every=50,
                 use_cuda=False):
        self.G = generator
        self.G_opt = gen_optimizer
        self.D = discriminator
        self.D_opt = dis_optimizer
        self.losses = {'G': [], 'D': [], 'GP': [], 'gradient_norm': []}
        self.num_steps = 0
        self.use_cuda = use_cuda
        self.gp_weight = gp_weight
        self.critic_iterations = critic_iterations
        self.print_every = print_every

        if self.use_cuda:
            self.G.cuda()
            self.D.cuda()

    def _critic_train_iteration(self, data):
        """ """
        # Get generated data
        batch_size = data.size()[0]
        generated_data = self.sample_generator(batch_size)

        # Calculate critic scores on real and generated data
        data = Variable(data)
        if self.use_cuda:
            data = data.cuda()
        d_real = self.D(data)
        d_generated = self.D(generated_data)

        # Get gradient penalty
        gradient_penalty = self._gradient_penalty(data, generated_data)
        self.losses['GP'].append(gradient_penalty.item())

        # Create total loss and optimize
        self.D_opt.zero_grad()
        d_loss = d_generated.mean() - d_real.mean() + gradient_penalty
        d_loss.backward()
        self.D_opt.step()

        # Record loss
        self.losses['D'].append(d_loss.item())

    def _generator_train_iteration(self, data):
        """ """
        self.G_opt.zero_grad()

        # Get generated data
        batch_size = data.size()[0]
        generated_data = self.sample_generator(batch_size)

        # Calculate loss and optimize
        d_generated = self.D(generated_data)
        g_loss = - d_generated.mean()
        g_loss.backward()
        self.G_opt.step()

        # Record loss
        self.losses['G'].append(g_loss.item())
    def _gradient_penalty(self, real_data, generated_data):
        batch_size = real_data.size()[0]

        # https://ai.stackexchange.com/questions/34926/why-do-we-use-a-linear-interpolation-of-fake-and-real-data-to-penalize-the-gradi
        # Calculate interpolation: one mixing weight per sample, broadcast over the remaining dims
        alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
        alpha = alpha.expand_as(real_data)
        if self.use_cuda:
            alpha = alpha.cuda()
        interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
        interpolated = Variable(interpolated, requires_grad=True)
        if self.use_cuda:
            interpolated = interpolated.cuda()

        # Calculate critic scores of interpolated examples
        prob_interpolated = self.D(interpolated)

        # Calculate gradients of the scores with respect to the examples
        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                               grad_outputs=torch.ones(prob_interpolated.size()).cuda()
                               if self.use_cuda else torch.ones(prob_interpolated.size()),
                               create_graph=True, retain_graph=True)[0]

        # Gradients have shape (batch_size, *data_dims),
        # so flatten to easily take the norm per example in the batch
        gradients = gradients.view(batch_size, -1)
        self.losses['gradient_norm'].append(gradients.norm(2, dim=1).mean().item())

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1) ** 2).mean()
    def _train_epoch(self, data_loader):
        for i, data in enumerate(data_loader):
            self.num_steps += 1
            self._critic_train_iteration(data[0])
            # Only update generator every |critic_iterations| iterations
            if self.num_steps % self.critic_iterations == 0:
                self._generator_train_iteration(data[0])

            if i % self.print_every == 0:
                print("Iteration {}".format(i + 1))
                print("D: {}".format(self.losses['D'][-1]))
                print("GP: {}".format(self.losses['GP'][-1]))
                print("Gradient norm: {}".format(self.losses['gradient_norm'][-1]))
                if self.num_steps > self.critic_iterations:
                    print("G: {}".format(self.losses['G'][-1]))
    def train(self, data_loader, epochs, save_training_gif=True):
        # Note: the GIF-saving path below comes from the image WGAN-GP repo this class is
        # reused from and assumes image-shaped generator output; disable it for DNA data.
        if save_training_gif:
            # Fix latents to see how image generation improves during training
            fixed_latents = Variable(self.G.sample_latent(64))
            if self.use_cuda:
                fixed_latents = fixed_latents.cuda()
            training_progress_images = []

        for epoch in range(epochs):
            print("\nEpoch {}".format(epoch + 1))
            self._train_epoch(data_loader)

            if save_training_gif:
                # Generate batch of images and convert to grid
                img_grid = make_grid(self.G(fixed_latents).cpu().data)
                # Convert to numpy and transpose axes to fit imageio convention
                # i.e. (width, height, channels)
                img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
                # Add image grid to training progress
                training_progress_images.append(img_grid)

        if save_training_gif:
            imageio.mimsave('./training_{}_epochs.gif'.format(epochs),
                            training_progress_images)

    def sample_generator(self, num_samples):
        latent_samples = Variable(self.G.sample_latent(num_samples))
        if self.use_cuda:
            latent_samples = latent_samples.cuda()
        generated_data = self.G(latent_samples)
        return generated_data

    def sample(self, num_samples):
        generated_data = self.sample_generator(num_samples)
        # Remove color channel (image-data convention from the reused repo)
        return generated_data.data.cpu().numpy()[:, 0, :, :]
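# A minimal sketch of how the pieces above could be wired together (added for illustration;
# not the original author's training loop). It assumes the Discriminator sketch above and
# the Meuleman dataset downloaded as described earlier. Note that DNA_BP is 500 per the
# quoted paper while that dataset holds 200 bp sequences, so DNA_BP would need to be set to
# 200 (or 500 bp sequences used instead) for the real and generated shapes to line up.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    # Rebuild the dataset with a transform that casts the float64 numpy encodings to
    # float32 tensors so they match the model weights.
    dna_data = sequence_dataloader.SequenceDatasetBase(
        data_path="./train_all_classifier_WM20220916.csv",
        sequence_length=200, sequence_encoding="polar",
        sequence_transform=lambda seq: torch.as_tensor(seq, dtype=torch.float32),
        cell_type_transform=None)
    data_loader = DataLoader(dna_data, batch_size=BATCH_SIZE, shuffle=True)

    generator = Generator()
    discriminator = Discriminator()

    # Adam with learning_rate 0.0001, beta_1 0.5 and beta_2 0.9, as in the quoted paper
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

    trainer = Trainer(generator, discriminator, g_optimizer, d_optimizer,
                      critic_iterations=10,  # 10 critic updates per generator update, per the paper
                      use_cuda=USE_CUDA)
    # The GIF-saving path in Trainer.train() assumes image-shaped data, so disable it here
    trainer.train(data_loader, epochs=NUM_EPOCHS, save_training_gif=False)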