# Reused from https://github.com/pinellolab/DNA-Diffusion/blob/codebase/src/data/sequence_dataloader.py
import pandas as pd
import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
class SequenceDatasetBase(Dataset):
    def __init__(self, data_path, sequence_length=200, sequence_encoding="polar",
                 sequence_transform=None, cell_type_transform=None):
        super().__init__()
        self.data = pd.read_csv(data_path, sep="\t")
        self.sequence_length = sequence_length
        self.sequence_encoding = sequence_encoding
        self.sequence_transform = sequence_transform
        self.cell_type_transform = cell_type_transform
        self.alphabet = ["A", "C", "T", "G"]
        self.check_data_validity()

    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        # Encode the DNA sequence at this index. check_data_validity() already
        # guarantees the alphabet is {A, C, T, G}, so sequences containing 'N'
        # never reach this point (the original guarded on 'N' here and silently
        # returned None for such rows).
        current_seq = self.data["raw_sequence"][index]
        X_seq = self.encode_sequence(current_seq, encoding=self.sequence_encoding)

        # Read the cell component at the current index
        X_cell_type = self.data["component"][index]

        if self.sequence_transform is not None:
            X_seq = self.sequence_transform(X_seq)
        if self.cell_type_transform is not None:
            X_cell_type = self.cell_type_transform(X_cell_type)

        return X_seq, X_cell_type
    def check_data_validity(self):
        """
        Checks that every sequence uses only the allowed alphabet
        and has the expected length.
        """
        if not set("".join(self.data["raw_sequence"])).issubset(set(self.alphabet)):
            raise ValueError("Sequence contains invalid characters.")

        uniq_raw_seq_len = self.data["raw_sequence"].str.len().unique()
        if len(uniq_raw_seq_len) != 1 or uniq_raw_seq_len[0] != self.sequence_length:
            raise ValueError("The sequence length does not match the data.")
    def encode_sequence(self, seq, encoding):
        """
        Encodes a sequence using the given encoding scheme ("polar", "onehot", "ordinal").
        """
        if encoding == "polar":
            seq = self.one_hot_encode(seq).T
            seq[seq == 0] = -1
        elif encoding == "onehot":
            seq = self.one_hot_encode(seq).T
        elif encoding == "ordinal":
            seq = np.array([self.alphabet.index(n) for n in seq])
        else:
            raise ValueError(f"Unknown encoding scheme: {encoding}")
        return seq
    # One-hot encode a single sequence from the dataset
    def one_hot_encode(self, seq):
        """
        One-hot encodes a sequence into a (sequence_length, len(alphabet)) array.
        """
        seq_len = len(seq)
        seq_array = np.zeros((self.sequence_length, len(self.alphabet)))
        for i in range(seq_len):
            seq_array[i, self.alphabet.index(seq[i])] = 1
        return seq_array
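    # Illustrative example (not in the original gist): with the alphabet
    # ["A", "C", "T", "G"] and sequence_length=4, one_hot_encode("ACTG") returns
    #   [[1, 0, 0, 0],
    #    [0, 1, 0, 0],
    #    [0, 0, 1, 0],
    #    [0, 0, 0, 1]]
    # and encode_sequence(..., encoding="polar") is its transpose with 0 -> -1,
    # so "polar"/"onehot" outputs have shape (4, sequence_length).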
class SequenceDatasetTrain(SequenceDatasetBase):
    def __init__(self, data_path="", **kwargs):
        super().__init__(data_path=data_path, **kwargs)


class SequenceDatasetValidation(SequenceDatasetBase):
    def __init__(self, data_path="", **kwargs):
        super().__init__(data_path=data_path, **kwargs)


class SequenceDatasetTest(SequenceDatasetBase):
    def __init__(self, data_path="", **kwargs):
        super().__init__(data_path=data_path, **kwargs)
class SequenceDataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_path=None,
        val_path=None,
        test_path=None,
        sequence_length=200,
        sequence_encoding="polar",
        sequence_transform=None,
        cell_type_transform=None,
        batch_size=None,
        num_workers=1
    ):
        super().__init__()
        self.datasets = dict()
        # Expose a dataloader hook only when the corresponding path is given;
        # these instance attributes shadow the LightningDataModule methods.
        self.train_dataloader, self.val_dataloader, self.test_dataloader = None, None, None
        if train_path:
            self.datasets["train"] = train_path
            self.train_dataloader = self._train_dataloader
        if val_path:
            self.datasets["validation"] = val_path
            self.val_dataloader = self._val_dataloader
        if test_path:
            self.datasets["test"] = test_path
            self.test_dataloader = self._test_dataloader
        self.sequence_length = sequence_length
        self.sequence_encoding = sequence_encoding
        self.sequence_transform = sequence_transform
        self.cell_type_transform = cell_type_transform
        self.batch_size = batch_size
        self.num_workers = num_workers
    def setup(self, stage=None):  # `stage` argument is required by the Lightning API
        if "train" in self.datasets:
            self.train_data = SequenceDatasetTrain(
                data_path=self.datasets["train"],
                sequence_length=self.sequence_length,
                sequence_encoding=self.sequence_encoding,
                sequence_transform=self.sequence_transform,
                cell_type_transform=self.cell_type_transform
            )
        if "validation" in self.datasets:
            self.val_data = SequenceDatasetValidation(
                data_path=self.datasets["validation"],
                sequence_length=self.sequence_length,
                sequence_encoding=self.sequence_encoding,
                sequence_transform=self.sequence_transform,
                cell_type_transform=self.cell_type_transform
            )
        if "test" in self.datasets:
            self.test_data = SequenceDatasetTest(
                data_path=self.datasets["test"],
                sequence_length=self.sequence_length,
                sequence_encoding=self.sequence_encoding,
                sequence_transform=self.sequence_transform,
                cell_type_transform=self.cell_type_transform
            )
    def _train_dataloader(self):
        return DataLoader(self.train_data,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def _val_dataloader(self):
        # Validation and test sets should not be shuffled
        return DataLoader(self.val_data,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def _test_dataloader(self):
        return DataLoader(self.test_data,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          pin_memory=True)
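# Minimal usage sketch (illustrative, not in the original gist; file paths and
# batch size are placeholders):
#
#   dm = SequenceDataModule(train_path="train.tsv", val_path="val.tsv",
#                           sequence_length=200, sequence_encoding="polar",
#                           batch_size=128)
#   dm.setup()
#   X_seq, X_cell_type = next(iter(dm.train_dataloader()))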
# ---------------------------------------------------------------------------
# Second file in the gist: WGAN-GP training script for DNA sequence generation
# ---------------------------------------------------------------------------
import imageio
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from torch.autograd import Variable
from torch.autograd import grad as torch_grad
import sequence_dataloader
'''
https://www.biorxiv.org/content/10.1101/2022.07.26.501466v1.full.pdf#page=10

Generative Adversarial Network

To train a GAN model, we used a Wasserstein GAN architecture with gradient penalty,
similar to earlier work. The model consists of two parts: a generator and a discriminator.
The generator takes noise as input (size 128),
followed by a dense layer with 64,000 (500 * 128) units with ELU activation,
a reshape layer (500, 128),
a convolution tower of 5 convolution blocks with skip connections,
a 1D convolution layer with 4 filters with kernel width 1, and finally a SOFTMAX activation layer.
The output of the generator is a 500 x 4 matrix, which represents a one-hot encoded DNA sequence.
The discriminator takes a 500 bp one-hot encoded DNA sequence as input (real or fake),
followed by a 1D convolution layer with 128 filters with kernel width 1,
a convolution tower of 5 convolution blocks with skip connections, a flatten layer,
and finally a dense layer with 1 unit.
Each block in the convolution tower consists of a RELU activation layer
followed by a 1D convolution with 128 filters with kernel width 5.
The noise is generated by the numpy.random.normal(0, 1, (batch_size, 128)) command.
We used a batch size of 128.
For every train_on_batch iteration of the generator, we performed 10 train_on_batch
iterations for the discriminator.
We used the Adam optimizer with learning_rate of 0.0001, beta_1 of 0.5, and beta_2 of 0.9.
We trained the models for around 260,000 batch training iterations for KC and
around 160,000 batch training iterations for MEL.
'''
BATCH_SIZE = 128  # the paper uses a batch size of 128 (260,000 is the KC iteration count)
SIZE_OF_INPUT = 128  # latent noise dimension
SIZE_OF_FEATURE_MAP = 128
DNA_BP = 500
SIZE_OF_HIDDEN_LAYERS = DNA_BP * SIZE_OF_FEATURE_MAP  # 64,000 (500 * 128) dense units
NUM_OF_1D_CONV_FILTERS = 4
NUM_OF_CONV_BLOCKS = 5
NUM_EPOCHS = 6000
# Adam hyperparameters from the paper quoted above (learning_rate 0.0001,
# beta_1 0.5, beta_2 0.9); the paper trains for ~260,000 (KC) / ~160,000 (MEL)
# batch training iterations.
LEARNING_RATE = 1e-4
BETA_1, BETA_2 = 0.5, 0.9
USE_CUDA = torch.cuda.is_available()
'''
Commands to obtain Meuleman's small DNA dataset:
wget https://www.dropbox.com/s/db6up7c0d4jwdp4/train_all_classifier_WM20220916.csv.gz?dl=2
mv 'train_all_classifier_WM20220916.csv.gz?dl=2' train_all_classifier_WM20220916.csv.gz
gunzip ./train_all_classifier_WM20220916.csv.gz
'''
encode_data = sequence_dataloader.SequenceDatasetBase(data_path="./train_all_classifier_WM20220916.csv",
                                                      sequence_length=200, sequence_encoding="polar",
                                                      sequence_transform=None, cell_type_transform=None)
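# Quick sanity check (illustrative, not in the original gist):
#   X_seq, X_cell_type = encode_data[0]
#   X_seq.shape  # (4, 200) for the "polar"/"onehot" encodings
# Note: this dataset yields 200 bp sequences, while the generator and
# discriminator below are sized for 500 bp (DNA_BP); the two must be
# reconciled before the Trainer can run end to end.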
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Dense layer: noise (size 128) -> 64,000 (500 * 128) units.
        # The paper uses ELU; this port deliberately replaces it with ReLU.
        self.linears = nn.Sequential(
            nn.Linear(SIZE_OF_INPUT, SIZE_OF_HIDDEN_LAYERS),
            nn.ReLU(),  # replace ELU with ReLU
            nn.Dropout()
        )
        # Convolution tower: 5 blocks with skip connections; each block is a
        # ReLU followed by a 1D convolution with 128 filters of kernel width 5.
        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                nn.ReLU(),
                nn.Conv1d(SIZE_OF_FEATURE_MAP, SIZE_OF_FEATURE_MAP,
                          kernel_size=5, padding=2)
            )
            for _ in range(NUM_OF_CONV_BLOCKS)
        ])
        # "A 1D convolution layer with 4 filters with kernel width 1" is a
        # single Conv1d with out_channels=4 (one filter per nucleotide).
        self.conv_1d = nn.Conv1d(SIZE_OF_FEATURE_MAP, NUM_OF_1D_CONV_FILTERS,
                                 kernel_size=1)

    def forward(self, x):
        dense_output = self.linears(x)
        # Reshape to (batch, 128 channels, 500 positions) for Conv1d
        h = dense_output.view(-1, SIZE_OF_FEATURE_MAP, DNA_BP)
        for block in self.conv_blocks:
            h = h + block(h)  # residual (skip) connection around each block
        logits = self.conv_1d(h)  # (batch, 4, 500)
        # The output is a 500 x 4 matrix per sample, which represents a
        # one-hot encoded DNA sequence; SOFTMAX over the 4 nucleotides.
        return torch.softmax(logits.permute(0, 2, 1), dim=-1)

    def sample_latent(self, num_samples):
        # "The noise is generated by the numpy.random.normal(0, 1,
        # (batch_size, 128)) command" -- torch.randn is the torch equivalent.
        return torch.randn(num_samples, SIZE_OF_INPUT)
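# The original gist stops at the generator, but the Trainer below also needs a
# discriminator. The following is a sketch, not the original author's code,
# following the architecture quoted from the paper above: a 1D convolution
# layer with 128 filters of kernel width 1, a convolution tower of 5 blocks
# with skip connections, a flatten layer, and a dense layer with 1 unit.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Input: (batch, 500, 4) one-hot encoded DNA, permuted to (batch, 4, 500)
        self.conv_in = nn.Conv1d(NUM_OF_1D_CONV_FILTERS, SIZE_OF_FEATURE_MAP,
                                 kernel_size=1)
        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                nn.ReLU(),
                nn.Conv1d(SIZE_OF_FEATURE_MAP, SIZE_OF_FEATURE_MAP,
                          kernel_size=5, padding=2)
            )
            for _ in range(NUM_OF_CONV_BLOCKS)
        ])
        self.dense = nn.Linear(SIZE_OF_FEATURE_MAP * DNA_BP, 1)

    def forward(self, x):
        h = self.conv_in(x.permute(0, 2, 1))
        for block in self.conv_blocks:
            h = h + block(h)  # residual (skip) connection around each block
        return self.dense(h.flatten(start_dim=1))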
# See https://zhuanlan.zhihu.com/p/25071913 for a Chinese explanation of WGAN-GP
# Reused from https://github.com/EmilienDupont/wgan-gp/blob/ef82364f2a2ec452a52fbf4a739f95039ae76fe3/training.py
class Trainer:
    def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer,
                 gp_weight=10, critic_iterations=5, print_every=50,
                 use_cuda=False):
        self.G = generator
        self.G_opt = gen_optimizer
        self.D = discriminator
        self.D_opt = dis_optimizer
        self.losses = {'G': [], 'D': [], 'GP': [], 'gradient_norm': []}
        self.num_steps = 0
        self.use_cuda = use_cuda
        self.gp_weight = gp_weight
        self.critic_iterations = critic_iterations
        self.print_every = print_every

        if self.use_cuda:
            self.G.cuda()
            self.D.cuda()
    def _critic_train_iteration(self, data):
        """Runs one optimization step of the critic (discriminator)."""
        # Get generated data
        batch_size = data.size()[0]
        generated_data = self.sample_generator(batch_size)

        # Calculate critic scores on real and generated data
        data = Variable(data)
        if self.use_cuda:
            data = data.cuda()
        d_real = self.D(data)
        d_generated = self.D(generated_data)

        # Get gradient penalty
        gradient_penalty = self._gradient_penalty(data, generated_data)
        self.losses['GP'].append(gradient_penalty.item())

        # Create total loss and optimize
        self.D_opt.zero_grad()
        d_loss = d_generated.mean() - d_real.mean() + gradient_penalty
        d_loss.backward()
        self.D_opt.step()

        # Record loss (.item() replaces the long-deprecated .data[0])
        self.losses['D'].append(d_loss.item())
    def _generator_train_iteration(self, data):
        """Runs one optimization step of the generator."""
        self.G_opt.zero_grad()

        # Get generated data
        batch_size = data.size()[0]
        generated_data = self.sample_generator(batch_size)

        # Calculate loss and optimize
        d_generated = self.D(generated_data)
        g_loss = -d_generated.mean()
        g_loss.backward()
        self.G_opt.step()

        # Record loss
        self.losses['G'].append(g_loss.item())
    def _gradient_penalty(self, real_data, generated_data):
        batch_size = real_data.size()[0]

        # https://ai.stackexchange.com/questions/34926/why-do-we-use-a-linear-interpolation-of-fake-and-real-data-to-penalize-the-gradi
        # Calculate interpolation: one alpha per example, broadcastable to any
        # data shape (the reused code hard-coded a 4D image shape here)
        alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
        alpha = alpha.expand_as(real_data)
        if self.use_cuda:
            alpha = alpha.cuda()
        interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
        interpolated = Variable(interpolated, requires_grad=True)
        if self.use_cuda:
            interpolated = interpolated.cuda()

        # Calculate critic scores of interpolated examples
        prob_interpolated = self.D(interpolated)

        # Calculate gradients of the scores with respect to the examples
        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                               grad_outputs=torch.ones(prob_interpolated.size()).cuda()
                               if self.use_cuda else torch.ones(prob_interpolated.size()),
                               create_graph=True, retain_graph=True)[0]

        # Gradients have shape (batch_size, *data_shape),
        # so flatten to easily take the norm per example in the batch
        gradients = gradients.view(batch_size, -1)
        self.losses['gradient_norm'].append(gradients.norm(2, dim=1).mean().item())

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate the norm and add an epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1) ** 2).mean()
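    # WGAN-GP objective implemented by the two train iterations above:
    #   L_D = E[D(fake)] - E[D(real)] + gp_weight * E[(||grad_xhat D(xhat)||_2 - 1)^2]
    #   L_G = -E[D(fake)]
    # where xhat is the random interpolation between real and generated samples.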
    def _train_epoch(self, data_loader):
        for i, data in enumerate(data_loader):
            self.num_steps += 1
            self._critic_train_iteration(data[0])
            # Only update the generator every |critic_iterations| iterations
            if self.num_steps % self.critic_iterations == 0:
                self._generator_train_iteration(data[0])

            if i % self.print_every == 0:
                print("Iteration {}".format(i + 1))
                print("D: {}".format(self.losses['D'][-1]))
                print("GP: {}".format(self.losses['GP'][-1]))
                print("Gradient norm: {}".format(self.losses['gradient_norm'][-1]))
                if self.num_steps > self.critic_iterations:
                    print("G: {}".format(self.losses['G'][-1]))
    def train(self, data_loader, epochs, save_training_gif=True):
        # Note: the GIF logic below is carried over from the image WGAN-GP code
        # and assumes image-shaped generator output; disable it for DNA sequences.
        if save_training_gif:
            # Fix latents to see how image generation improves during training
            fixed_latents = Variable(self.G.sample_latent(64))
            if self.use_cuda:
                fixed_latents = fixed_latents.cuda()
            training_progress_images = []

        for epoch in range(epochs):
            print("\nEpoch {}".format(epoch + 1))
            self._train_epoch(data_loader)

            if save_training_gif:
                # Generate a batch of images and convert to a grid
                img_grid = make_grid(self.G(fixed_latents).cpu().data)
                # Convert to numpy and transpose axes to fit the imageio
                # convention, i.e. (width, height, channels)
                img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
                # Add image grid to training progress
                training_progress_images.append(img_grid)

        if save_training_gif:
            imageio.mimsave('./training_{}_epochs.gif'.format(epochs),
                            training_progress_images)
    def sample_generator(self, num_samples):
        latent_samples = Variable(self.G.sample_latent(num_samples))
        if self.use_cuda:
            latent_samples = latent_samples.cuda()
        generated_data = self.G(latent_samples)
        return generated_data

    def sample(self, num_samples):
        generated_data = self.sample_generator(num_samples)
        # Remove color channel (image-specific, from the reused code)
        return generated_data.data.cpu().numpy()[:, 0, :, :]
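# Wiring sketch (illustrative, not in the original gist): hyperparameters follow
# the paper quoted above (Adam, lr=1e-4, betas=(0.5, 0.9), 10 critic iterations
# per generator iteration); the DataLoader over `encode_data` is an assumption,
# and the 200 bp vs 500 bp mismatch noted earlier still needs resolving.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    generator = Generator()
    discriminator = Discriminator()
    gen_optimizer = torch.optim.Adam(generator.parameters(),
                                     lr=LEARNING_RATE, betas=(BETA_1, BETA_2))
    dis_optimizer = torch.optim.Adam(discriminator.parameters(),
                                     lr=LEARNING_RATE, betas=(BETA_1, BETA_2))

    data_loader = DataLoader(encode_data, batch_size=BATCH_SIZE, shuffle=True)

    trainer = Trainer(generator, discriminator, gen_optimizer, dis_optimizer,
                      critic_iterations=10, use_cuda=USE_CUDA)
    trainer.train(data_loader, NUM_EPOCHS, save_training_gif=False)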