@syuntoku14
Created January 13, 2020 09:25
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader

from pytorch_pretrained_bert import BertModel, BertTokenizer
def is_word_unfeasible(word):
    def is_ascii(word):
        return all(ord(c) < 128 for c in word)

    return ("unused" in word
            or "#" in word
            or not is_ascii(word)
            or len(word) < 3)
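

# For illustration, a few hedged examples of what the filter above keeps and drops
# (the exact vocabulary entries depend on the bert-base-uncased tokenizer):
#   is_word_unfeasible("[unused10]")  -> True   ("unused" placeholder token)
#   is_word_unfeasible("##ing")       -> True   (subword piece, contains "#")
#   is_word_unfeasible("café")        -> True   (non-ASCII character)
#   is_word_unfeasible("an")          -> True   (shorter than 3 characters)
#   is_word_unfeasible("apple")       -> False  (kept as a training word)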
class BertDataset(Dataset):
    def __init__(self, device):
        self.device = device
        # Tokenizer gives access to the WordPiece vocabulary.
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # Load the BERT base model; it is only used as a frozen embedding lookup.
        self.bert = BertModel.from_pretrained("bert-base-uncased").to(device)
        for param in self.bert.parameters():
            param.requires_grad = False
        self.bert.eval()

        # Input characters are plain ASCII codes.
        self.CHAR_VOCAB_SIZE = 128
        words = self.tokenizer.vocab.keys()
        self.vocabs = [word for word in words if not is_word_unfeasible(word)]
        self.chars = [torch.LongTensor([ord(c) for c in word])
                      for word in self.vocabs]
        # (num_words, max_word_len), zero-padded
        self.chars = rnn.pad_sequence(self.chars).to(self.device).T

        # Targets: BERT's word embeddings for the kept vocabulary entries.
        ids = torch.LongTensor(
            self.tokenizer.convert_tokens_to_ids(self.vocabs)).to(self.device)
        self.word_embed = self.bert.embeddings.word_embeddings(
            ids.unsqueeze(0)).squeeze(0)

    def __len__(self):
        return len(self.vocabs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.chars[idx], self.word_embed[idx]
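

# Minimal usage sketch of the dataset (commented out so it does not run on import;
# it assumes the bert-base-uncased weights can be downloaded):
#   ds = BertDataset(torch.device("cpu"))
#   chars, embed = ds[0]
#   # chars: LongTensor of padded ASCII codes for one vocabulary word
#   # embed: the corresponding 768-dim BERT word embedding (the regression target)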
class Conv1dBlockBN(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, p=0.0):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channel, out_channel,
                      kernel_size=kernel_size, stride=stride),
            nn.Dropout(p),
            nn.PReLU(),
            nn.BatchNorm1d(out_channel)
        )

    def forward(self, x):
        return self.conv(x)
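

# Note: each Conv1dBlockBN is built with kernel_size=2 and stride=1, so it shrinks
# the sequence length by exactly one step. Stacking (char_len - 1) blocks in CNN_LM
# below therefore collapses a char_len-long character sequence to length 1.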
class CNN_LM(nn.Module):
    def __init__(self, char_vocab_size, char_len, embed_dim, chan_size, hid_size, bert_hid_size):
        super().__init__()
        self.embedding = nn.Embedding(char_vocab_size, embed_dim)
        convs = []
        for i in range(char_len - 1):
            if i == 0:
                convs.append(Conv1dBlockBN(embed_dim, chan_size, 2, stride=1))
            else:
                convs.append(Conv1dBlockBN(chan_size, chan_size, 2, stride=1))
        self.convs = nn.Sequential(*convs)
        self.fc1 = nn.Linear(chan_size, hid_size)
        self.fc2 = nn.Linear(hid_size, bert_hid_size)

    def forward(self, x):
        # (batch_size, embed_dim, char_len)
        x = self.embedding(x).permute(0, 2, 1)
        x = self.convs(x)        # (batch_size, chan_size, 1)
        x = x.squeeze(2)         # (batch_size, chan_size)
        x = F.relu(self.fc1(x))  # (batch_size, hid_size)
        x = self.fc2(x)          # (batch_size, bert_hid_size)
        return x
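

# Shape sketch for a forward pass (commented out; the numbers are illustrative only):
#   m = CNN_LM(char_vocab_size=128, char_len=16, embed_dim=8,
#              chan_size=32, hid_size=256, bert_hid_size=768)
#   x = torch.randint(0, 128, (4, 16))  # batch of 4 words, 16 padded ASCII ids each
#   y = m(x)                            # -> (4, 768), regressed against BERT embeddings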
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--embed_size', type=int, default=8)
    parser.add_argument('--hidden_size', type=int, default=256)
    parser.add_argument('--channel_size', type=int, default=32)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    args = parser.parse_args()

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = BertDataset(device)

    # 80/20 train/validation split
    SAMPLE_SIZE = len(dataset)
    TRAIN_SIZE = int(SAMPLE_SIZE * 0.8)
    train_dataset, val_dataset = data.random_split(
        dataset, [TRAIN_SIZE, SAMPLE_SIZE - TRAIN_SIZE])
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size,
                                shuffle=False, num_workers=0)

    CHAR_VOCAB_SIZE = 128
    BERT_EMBED_DIM = 768
    model = CNN_LM(char_vocab_size=CHAR_VOCAB_SIZE,
                   char_len=dataset.chars.shape[1], embed_dim=args.embed_size,
                   chan_size=args.channel_size, hid_size=args.hidden_size,
                   bert_hid_size=BERT_EMBED_DIM)
    model.to(device)

    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    print("training start")
    # Train the model
    for epoch in range(args.num_epochs):
        model.train()
        for batch, (inputs, targets) in enumerate(train_dataloader):
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward and optimize
            model.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 20 == 0:
                print('Training: Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, args.num_epochs, batch,
                              len(train_dataloader), loss.item()))

        # Evaluate on the held-out split
        test_loss = 0.0
        model.eval()
        with torch.no_grad():
            for batch, (inputs, targets) in enumerate(val_dataloader):
                outputs = model(inputs)
                test_loss += criterion(outputs, targets).item()
        print('Test: Epoch {}, Loss: {:.4f}'
              .format(epoch + 1, test_loss / len(val_dataloader)))

    # Save the model checkpoint
    os.makedirs('data', exist_ok=True)
    torch.save(model.state_dict(), 'data/bert_cnn.ckpt')
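
# Example invocation (the script filename is hypothetical; any of the flags defined
# above can be passed to override the defaults):
#   python train_bert_cnn.py --num_epochs 100 --batch_size 256 --learning_rate 0.001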