This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ID | Prediction | |
---|---|---|
5865780764279606380_1 | 0 | |
5865752664273366197_0 | 0 | |
5865752674274689480_0 | 0 | |
5865687714272887778_1 | 0 | |
5865738854271318689_1 | 0 | |
5865942504274625239_1 | 0 | |
5865738824272575861_0 | 0 | |
5865943114272638580_0 | 0 | |
5865784484277061032_2 | 0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict_inference_src(self, src: str, max_len: int = 140) -> str: | |
""" | |
Used on inference to predict source | |
:param src: String text as source | |
:param max_len: max length of the generator | |
:return: predicted text | |
""" | |
src_field, tgt_field = \ | |
self.constructed_iterator_field['src_field'], self.constructed_iterator_field['tgt_field'] | |
end_token_id = src_field.vocab.stoi['</s>'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
No | Arch | BLEU | |
---|---|---|---|
1 | GRU Seq2seq + Attention | 5.696920130368399 | |
2 | Transformer | 16.243497976699246 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pytorch_lightning.core.lightning import LightningModule | |
from recibrew.data_util import construct_torchtext_iterator | |
from recibrew.nn.gru_bahdanau import Encoder, Decoder | |
from recibrew.nn.transformers import FullTransformer | |
import torch | |
from torch.optim import AdamW | |
class TransformersLightning(LightningModule): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class FullTransformer(Module): | |
def __init__(self, num_vocab, num_embedding=128, dim_feedforward=512, num_encoder_layer=4, | |
num_decoder_layer=4, dropout=0.3, padding_idx=1, max_seq_len=140, nhead=8): | |
super(FullTransformer, self).__init__() | |
self.padding_idx = padding_idx | |
# [x : seq_len, batch_size ] | |
self.inp_embedding = Embedding(num_vocab , num_embedding, padding_idx=padding_idx) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GRUBahdanauLightning(LightningModule): | |
""" | |
GRU + Bahdanau attention, Research environment | |
""" | |
def __init__(self, train_csv='../data/processed/train.csv', dev_csv='../data/processed/dev.csv', | |
test_csv='../data/processed/test.csv', lr=1e-3, gru_params=None, padding_idx=1, | |
max_len=140): | |
""" | |
:param train_csv: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def construct_torchtext_iterator(train_csv: str, dev_csv: str, test_csv: str, device: str = 'cuda', | |
batch_size: int = 64, max_vocab: int = 3000, fix_length=144) -> Dict[str, Any]: | |
""" | |
Construct the iterator used to train the data. | |
:param train_csv: train_csv file csv | |
:param dev_csv: dev_csv file csv | |
:param test_csv: test_csv file csv | |
:param device: device of the torch tensor ('cpu' or 'cuda') | |
:param batch_size : the batch size of each iterator (train, dev, test) | |
:param max_vocab : max vocab in dictionary |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
means = -1 | |
count = 0 | |
fig, axes = plt.subplots(1, 10, gridspec_kw = {'wspace':0.1, 'hspace':0.1}, figsize=(16,16)) | |
while means < 1 : | |
random_latent_vectors = np.random.normal(size = (16, latent_dim), loc=means, | |
scale=0.0) | |
random_latent_vectors = random_latent_vectors.mean(axis=0) | |
generated_images = generator.predict(np.array([random_latent_vectors])) | |
axes[count].set_xticklabels([]) | |
axes[count].set_yticklabels([]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
x_train = data_train_gan | |
iterations = 15000 | |
batch_size = 32 | |
save_dir = '.' | |
start = 0 | |
for step in tqdm_notebook(range(iterations)): | |
random_latent_vectors = np.random.normal(size = (batch_size, latent_dim)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.optimizers import Adam, RMSprop | |
# Build the combined GAN used to train the generator:
# latent noise -> generator -> discriminator -> real/fake score.
# The discriminator is frozen *before* `gan` is compiled, so training `gan`
# updates only the generator's weights. (Assumes `discriminator` was compiled
# on its own earlier, so it still trains when fit directly — TODO confirm.)
discriminator.trainable = False
gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)

# RMSprop with gradient-value clipping (clipvalue=1.0) and a tiny decay.
# Uses the `RMSprop` name imported from keras.optimizers above; this is the
# same class as `keras.optimizers.RMSprop`.
# NOTE(review): `lr` and `decay` are the legacy Keras argument names. Newer
# Keras expects `learning_rate` and has dropped `decay`, so confirm against
# the installed version before renaming.
gan_optimizer = RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')
Newer Older