# Gist by @MaximumEntropy, created October 22, 2017.
import operator

import numpy as np
import torch
from sklearn.utils import shuffle  # assumed source of `shuffle`; permutes lists in unison
from torch.autograd import Variable


class DataIterator(object):
    """Data Iterator."""

    def _trim_vocab(self, vocab, vocab_size):
        """Keep the `vocab_size` most frequent words plus the special tokens."""
        # Discard start, end, pad and unk tokens if already present.
        for token in ('<s>', '<pad>', '</s>', '<unk>'):
            if token in vocab:
                del vocab[token]
        word2id = {
            '<s>': 0,
            '<pad>': 1,
            '</s>': 2,
            '<unk>': 3,
        }
        id2word = {
            0: '<s>',
            1: '<pad>',
            2: '</s>',
            3: '<unk>',
        }
        # Sort words by descending corpus frequency.
        sorted_vocab = sorted(
            vocab.items(),
            key=operator.itemgetter(1),
            reverse=True
        )
        if vocab_size != -1:
            sorted_words = [x[0] for x in sorted_vocab[:vocab_size]]
        else:
            sorted_words = [x[0] for x in sorted_vocab]
        # Ids 0-3 are reserved for the special tokens above.
        for ind, word in enumerate(sorted_words):
            word2id[word] = ind + 4
            id2word[ind + 4] = word
        return word2id, id2word
    def construct_vocab(self, sentences, vocab_size, lowercase=False):
        """Create vocabulary."""
        vocab = {}
        for sentence in sentences:
            # Accept both raw strings and pre-tokenized lists of words.
            if isinstance(sentence, str):
                if lowercase:
                    sentence = sentence.lower()
                sentence = sentence.split()
            for word in sentence:
                if word not in vocab:
                    vocab[word] = 1
                else:
                    vocab[word] += 1
        word2id, id2word = self._trim_vocab(vocab, vocab_size)
        return word2id, id2word
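
# A minimal usage sketch (not from the original gist): special tokens take
# ids 0-3, and corpus words take ids from 4 onward in descending frequency.
#
#     it = DataIterator()
#     word2id, id2word = it.construct_vocab(['a b b c', 'b c'], vocab_size=2)
#     # word2id == {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, 'b': 4, 'c': 5}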

class ParallelCorpusDataIterator(DataIterator):
    """Parallel corpus data iterator."""

    def __init__(self, src, trg, vocab_size, common_vocab=False):
        """Initialize params."""
        self.src = src
        self.trg = trg
        self.vocab_size = vocab_size
        self.common_vocab = common_vocab
        self.read_parallel_data()

    def read_parallel_data(self):
        """Read data from files."""
        print('Reading source file ...')
        with open(self.src, 'r') as f:
            src_lines = [line.strip().split() for line in f]
        print('Reading target file ...')
        with open(self.trg, 'r') as f:
            trg_lines = [line.strip().split() for line in f]
        if not self.common_vocab:
            print('Building source vocabulary ...')
            src_word2id, src_id2word = self.construct_vocab(
                src_lines, self.vocab_size
            )
            print('Building target vocabulary ...')
            trg_word2id, trg_id2word = self.construct_vocab(
                trg_lines, self.vocab_size
            )
            self.src = {
                'data': src_lines,
                'word2id': src_word2id,
                'id2word': src_id2word
            }
            self.trg = {
                'data': trg_lines,
                'word2id': trg_word2id,
                'id2word': trg_id2word
            }
        else:
            print('Building common vocabulary ...')
            word2id, id2word = self.construct_vocab(
                src_lines + trg_lines, self.vocab_size
            )
            self.src = {
                'data': src_lines, 'word2id': word2id, 'id2word': id2word
            }
            self.trg = {
                'data': trg_lines, 'word2id': word2id, 'id2word': id2word
            }
    def shuffle_dataset(self):
        """Shuffle dataset."""
        # sklearn.utils.shuffle applies the same permutation to both lists,
        # so source/target sentence pairs stay aligned.
        self.src['data'], self.trg['data'] = shuffle(
            self.src['data'], self.trg['data']
        )
    def get_parallel_minibatch(
        self, index, batch_size, max_len_src, max_len_trg
    ):
        """Prepare minibatch."""
        # Truncate to the maximum length and add start / end tokens.
        src_lines = [
            ['<s>'] + line[:max_len_src] + ['</s>']
            for line in self.src['data'][index: index + batch_size]
        ]
        trg_lines = [
            ['<s>'] + line[:max_len_trg] + ['</s>']
            for line in self.trg['data'][index: index + batch_size]
        ]
        # Sort the batch by descending source length, as required by
        # torch.nn.utils.rnn.pack_padded_sequence.
        src_lens = [len(line) for line in src_lines]
        sorted_indices = np.argsort(src_lens)[::-1]
        sorted_src_lines = [src_lines[idx] for idx in sorted_indices]
        sorted_trg_lines = [trg_lines[idx] for idx in sorted_indices]
        sorted_src_lens = [len(line) for line in sorted_src_lines]
        sorted_trg_lens = [len(line) for line in sorted_trg_lines]
        max_src_len = max(sorted_src_lens)
        max_trg_len = max(sorted_trg_lens)

        # Map words to ids (OOV words map to '<unk>') and pad each line to
        # the longest sequence in the batch.
        input_lines_src = [
            [self.src['word2id'].get(w, self.src['word2id']['<unk>'])
             for w in line] +
            [self.src['word2id']['<pad>']] * (max_src_len - len(line))
            for line in sorted_src_lines
        ]
        # Decoder input drops the final '</s>'; decoder target drops the
        # leading '<s>'. This yields the usual one-step-shifted pair for
        # teacher forcing, each of length max_trg_len - 1 after padding.
        input_lines_trg = [
            [self.trg['word2id'].get(w, self.trg['word2id']['<unk>'])
             for w in line[:-1]] +
            [self.trg['word2id']['<pad>']] * (max_trg_len - len(line))
            for line in sorted_trg_lines
        ]
        output_lines_trg = [
            [self.trg['word2id'].get(w, self.trg['word2id']['<unk>'])
             for w in line[1:]] +
            [self.trg['word2id']['<pad>']] * (max_trg_len - len(line))
            for line in sorted_trg_lines
        ]
        input_lines_src = Variable(torch.LongTensor(input_lines_src)).cuda()
        input_lines_trg = Variable(torch.LongTensor(input_lines_trg)).cuda()
        output_lines_trg = Variable(torch.LongTensor(output_lines_trg)).cuda()
        # The original marked this Variable volatile=True; volatility
        # propagates to everything computed from it and would disable
        # gradients during training, so it is dropped here.
        sorted_src_lens = Variable(
            torch.LongTensor(sorted_src_lens)
        ).squeeze().cuda()
        return {
            'input_src': input_lines_src,
            'input_trg': input_lines_trg,
            'output_trg': output_lines_trg,
            'src_lens': sorted_src_lens
        }
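
# Usage sketch (a hedged example, not from the original gist; 'train.src'
# and 'train.trg' are hypothetical whitespace-tokenized parallel files, and
# a CUDA device is assumed since minibatch tensors are moved to the GPU):
#
#     iterator = ParallelCorpusDataIterator(
#         'train.src', 'train.trg', vocab_size=50000
#     )
#     iterator.shuffle_dataset()
#     for i in range(0, len(iterator.src['data']), 80):
#         batch = iterator.get_parallel_minibatch(
#             i, batch_size=80, max_len_src=50, max_len_trg=50
#         )
#         # batch['input_src']  : (batch, max_src_len) source ids, padded
#         # batch['input_trg']  : decoder inputs, '<s>' + tokens
#         # batch['output_trg'] : decoder targets, tokens + '</s>'
#         # batch['src_lens']   : source lengths, sorted descending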