Skip to content

Instantly share code, notes, and snippets.

@snakers4
Created March 1, 2019 09:14
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save snakers4/f2188adf217baabad2ff6733d48cfaef to your computer and use it in GitHub Desktop.
Best pretraining for the Russian language: embedding-bag interfaces
class BertEmbeddingBag(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    Word embeddings come from an n-gram embedding bag whose weights are loaded
    from ``config.ngram_matrix_path``; that matrix must have ``config.emb_size``
    columns. Position and token-type embeddings are learned in ``emb_size``
    dimensions. The summed embeddings are projected to ``config.hidden_size``
    when the two sizes differ, then layer-normalised and passed through dropout.

    ``config.old_bag`` selects the legacy ``OldFastTextEmbeddingBag`` (which
    maps raw words to n-gram ids through a pickled dict) instead of
    ``FastTextEmbeddingBag`` (which consumes pre-computed n-gram ids).
    """

    def __init__(self, config):
        super(BertEmbeddingBag, self).__init__()
        ngram_matrix = np.load(config.ngram_matrix_path)
        # Fail fast, before building the bag, if the matrix width is not emb_size.
        assert ngram_matrix.shape[1] == config.emb_size
        self.old_bag = config.old_bag
        if self.old_bag:
            self.embedding_bag = OldFastTextEmbeddingBag(upkl(config.ngram_dict_path),
                                                         ngram_matrix,
                                                         config.device)
        else:
            self.embedding_bag = FastTextEmbeddingBag(ngram_matrix,
                                                      config.device)
        # The bag copied the weights; release the numpy matrix.
        del ngram_matrix
        self.hidden_size = config.hidden_size
        self.emb_size = config.emb_size
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.emb_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.emb_size)
        # Always created (even when unused) so checkpoint state_dict keys stay stable.
        self.linear = nn.Linear(self.emb_size, self.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids,
                token_type_ids=None,
                pad_ids=None
                ):
        """Embed a batch of sequences.

        Args:
            input_ids: new bag: integer tensor of n-gram ids whose first two
                dims are (batch, seq_len); presumably (batch, seq_len,
                max_ngrams) -- TODO confirm. Legacy bag: the words passed
                straight to ``OldFastTextEmbeddingBag``.
            token_type_ids: (batch, seq_len) segment ids. Optional for the new
                bag (defaults to zeros). Required for the legacy bag, because
                the sequence length is read from it.
            pad_ids: new bag only; 2-D tensor holding, per word, the index of
                its last valid n-gram (the bag uses ``pad_ids + 1`` as length).

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).

        Raises:
            ValueError: if ``token_type_ids`` is missing with the legacy bag.
        """
        if self.old_bag:
            if token_type_ids is None:
                raise ValueError(
                    "token_type_ids is required with the old embedding bag: "
                    "the sequence length cannot be inferred from raw words")
            seq_length = token_type_ids.size(1)
        else:
            seq_length = input_ids.size(1)
            assert pad_ids is not None and pad_ids.dim() == 2
            if token_type_ids is None:
                # One segment id per word position, not per n-gram:
                # zeros_like(input_ids) would keep any trailing n-gram dim.
                token_type_ids = torch.zeros(input_ids.shape[:2],
                                             dtype=torch.long,
                                             device=input_ids.device)
        # token_type_ids is guaranteed non-None here, so it can supply device/shape.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=token_type_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(token_type_ids)
        # The bag returns one row per word; regroup them into sequences.
        words_embeddings = self.embedding_bag(input_ids, pad_ids).view((-1,
                                                                        seq_length,
                                                                        self.emb_size))
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        if self.hidden_size != self.emb_size:
            embeddings = self.linear(embeddings)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class FastTextEmbeddingBag(EmbeddingBag):
    """EmbeddingBag over pre-computed n-gram ids, initialised from a matrix.

    Each word is a row of n-gram ids padded to a common width. Its length is
    given by ``pad_ids + 1``, where ``pad_ids`` is the index of the last valid
    n-gram. The bag is reduced with ``EmbeddingBag``'s default mode ('mean').
    A word of length 0 yields a zero vector.
    """

    def __init__(self, input_matrix, device):
        # The number of rows (n-gram vocabulary size) and the embedding width
        # are both taken from the pre-trained matrix.
        super().__init__(input_matrix.shape[0], input_matrix.shape[1])
        # Kept for interface compatibility; forward() follows the device of its inputs.
        self.device = device
        with torch.no_grad():
            self.weight.copy_(torch.as_tensor(input_matrix, dtype=torch.float32))

    def forward(self, ids, pad_ids):
        """Return one pooled embedding per word.

        Args:
            ids: integer tensor whose last dim holds a word's (padded) n-gram
                ids; all leading dims are flattened into the word axis.
            pad_ids: integer tensor with one entry per word (same number of
                elements as the flattened leading dims of ``ids``): the index
                of the word's last valid n-gram.

        Returns:
            Tensor of shape (num_words, embedding_dim).
        """
        device = ids.device
        ids = ids.reshape(-1, ids.shape[-1])
        width = ids.shape[-1]
        # Clamp so the lengths always agree with the ids actually selected.
        lengths = (pad_ids.reshape(-1).to(device).long() + 1).clamp(min=0, max=width)
        # Keep ids as integers (no float round-trip) and select all valid
        # entries at once, in row-major order.
        mask = torch.arange(width, device=device).unsqueeze(0) < lengths.unsqueeze(1)
        ind = ids[mask].long()
        # Exclusive cumulative sum: the start offset of each word's bag.
        offsets = torch.cumsum(lengths, 0) - lengths
        return super().forward(ind, offsets)
class OldFastTextEmbeddingBag(EmbeddingBag):
    """Legacy EmbeddingBag that maps raw words to n-gram ids via a dict.

    For each word, the n-grams produced by ``word_ngrams`` that are present in
    ``ngram_dict`` form one bag. A word with no known n-grams falls back to the
    ``'#UNK#'`` entry. Bags are reduced with ``EmbeddingBag``'s default mode
    ('mean').
    """

    def __init__(self, ngram_dict, input_matrix, device):
        # The number of rows (n-gram vocabulary size) and the embedding width
        # are both taken from the pre-trained matrix.
        super().__init__(input_matrix.shape[0], input_matrix.shape[1])
        self.ngram_dict = ngram_dict
        # Kept for interface compatibility; forward() uses the weight's device,
        # which stays correct after .to()/.cuda() or DataParallel replication.
        self.device = device
        with torch.no_grad():
            self.weight.copy_(torch.as_tensor(input_matrix, dtype=torch.float32))

    def forward(self, words, pad_ids=None):
        """Return one pooled embedding per word.

        Args:
            words: iterable of words accepted by ``word_ngrams``.
            pad_ids: unused; accepted for call compatibility with
                ``FastTextEmbeddingBag``.

        Returns:
            Tensor of shape (len(words), embedding_dim).

        Raises:
            KeyError: if a word has no known n-grams and ``'#UNK#'`` is
                missing from ``ngram_dict``.
        """
        word_subinds = []
        word_offsets = []
        for word in words:
            subinds = [self.ngram_dict[gram] for gram in word_ngrams(word) if gram in self.ngram_dict]
            if not subinds:
                subinds = [self.ngram_dict['#UNK#']]
            # Each bag starts where the previous bags' indices end.
            word_offsets.append(len(word_subinds))
            word_subinds.extend(subinds)
        device = self.weight.device
        ind = torch.tensor(word_subinds, dtype=torch.long, device=device)
        offsets = torch.tensor(word_offsets, dtype=torch.long, device=device)
        return super().forward(ind, offsets)
class BertEmbeddingsWarmStart(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    Word embeddings are initialised from the pre-trained matrix stored at
    ``config.ngram_matrix_path`` (loaded with ``np.load``). The matrix width
    must equal ``config.hidden_size``, because word, position and token-type
    embeddings are summed without a projection.

    ``nn.Embedding.from_pretrained`` freezes the loaded weights by default.
    Set ``config.freeze_word_embeddings = False`` to fine-tune them; when the
    attribute is absent the weights stay frozen, as before.
    """

    def __init__(self, config):
        super(BertEmbeddingsWarmStart, self).__init__()
        ngram_matrix = torch.from_numpy(np.load(config.ngram_matrix_path)).float()
        if ngram_matrix.dim() != 2 or ngram_matrix.size(1) != config.hidden_size:
            raise ValueError(
                "pre-trained embedding matrix must have shape (vocab, {}), got {}".format(
                    config.hidden_size, tuple(ngram_matrix.shape)))
        freeze = getattr(config, 'freeze_word_embeddings', True)
        self.word_embeddings = nn.Embedding.from_pretrained(ngram_matrix, freeze=freeze)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        """Embed (batch, seq_len) token ids into (batch, seq_len, hidden_size).

        Args:
            input_ids: 2-D integer tensor of word ids.
            token_type_ids: optional segment ids with the same shape as
                ``input_ids``; defaults to zeros.
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment