Skip to content

Instantly share code, notes, and snippets.

@ihsgnef
Last active June 21, 2017 16:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ihsgnef/ac01fbd8eaba01145c8a048a0e8a0678 to your computer and use it in GitHub Desktop.
Save ihsgnef/ac01fbd8eaba01145c8a048a0e8a0678 to your computer and use it in GitHub Desktop.
import random
import numpy as np
from collections import defaultdict
import torch
from torch.autograd import Variable
from vocab import PAD_ID
class Iterator(object):
    """Serve (question, answer) mini-batches bucketed by padded length.

    Every non-empty question is padded with ``PAD_ID`` up to the next
    multiple of ``bucket_size``; questions sharing a padded length are
    grouped and cut into chunks of at most ``batch_size`` samples.

    Batches are built once, at construction time.  They are only shuffled
    when ``finalize()`` is called, and only if ``shuffle`` is true.

    Args:
        dataset: iterable of ``(question, answer)`` pairs, where
            ``question`` is a sequence of token ids.
        batch_size: maximum number of samples per batch.
        bucket_size: padded question lengths are multiples of this value.
        shuffle: whether ``finalize()`` shuffles the batch order.
    """

    def __init__(self, dataset, batch_size, bucket_size=4,
                 shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.bucket_size = bucket_size
        self.shuffle = shuffle
        self.epoch = 0
        self.iteration = 0
        self.batch_index = 0
        self.is_end_epoch = False
        self.create_batches()

    def create_batches(self):
        """(Re)build ``self.batches`` from ``self.dataset``.

        Empty questions are skipped.  Each batch is a pair
        ``(questions, answers)`` of equally long tuples.
        """
        self.batches = []
        buckets = defaultdict(list)
        for question, answer in self.dataset:
            if len(question) == 0:
                continue
            # Pad a copy: appending to `question` itself would silently
            # modify the caller's dataset.
            padded = list(question)
            remainder = len(padded) % self.bucket_size
            if remainder:
                padded.extend([PAD_ID] * (self.bucket_size - remainder))
            buckets[len(padded)].append((padded, answer))
        for samples in buckets.values():
            for start in range(0, len(samples), self.batch_size):
                chunk = samples[start:start + self.batch_size]
                questions, answers = zip(*chunk)
                self.batches.append((questions, answers))

    @property
    def size(self):
        """Number of batches in one pass over the data."""
        return len(self.batches)

    def __len__(self):
        """Same as ``size``, so that ``len(iterator)`` works."""
        return self.size

    def finalize(self, reset=False):
        """Shuffle the batch order (if enabled) and optionally reset counters.

        Args:
            reset: when true, ``epoch``, ``iteration`` and ``batch_index``
                are set back to 0.
        """
        if self.shuffle:
            random.shuffle(self.batches)
        if reset:
            self.epoch = 0
            self.iteration = 0
            self.batch_index = 0

    def next_batch(self, device=-1, train=True):
        """Return the next batch, wrapping around after the last one.

        Args:
            device: ``-1`` leaves the tensors where ``torch.LongTensor``
                created them; any other value is passed to ``.cuda(device)``.
            train: the returned Variables are created with
                ``volatile=not train``.

        Returns:
            ``(questions, answers)``: ``questions`` is the transposed
            question matrix (length, batch_size) and ``answers`` holds one
            id per sample.

        Raises:
            IndexError: if the iterator holds no batches.
        """
        if not self.batches:
            # Fail before touching the counters; otherwise the lookup below
            # raises a bare IndexError after the state was already advanced.
            raise IndexError('Iterator has no batches to serve')
        self.iteration += 1
        if self.batch_index == 0:
            self.epoch += 1
        self.is_end_epoch = (self.batch_index == self.size - 1)
        questions, answers = self.batches[self.batch_index]
        questions = torch.LongTensor(questions).t()  # length, batch_size
        answers = torch.LongTensor(answers)
        self.batch_index = (self.batch_index + 1) % self.size
        if device != -1:
            questions = questions.cuda(device)
            answers = answers.cuda(device)
        else:
            questions = questions.contiguous()
            answers = answers.contiguous()
        questions = Variable(questions, volatile=not train)
        answers = Variable(answers, volatile=not train)
        return questions, answers

    @property
    def epoch_detail(self):
        """Return ``(size, iteration, iteration / size)``.

        The ratio is reported as ``0.0`` when there are no batches, instead
        of raising ZeroDivisionError.
        """
        progress = self.iteration / self.size if self.size else 0.0
        return self.size, self.iteration, progress
import re
from collections import defaultdict
from vocab import Vocab
from question_database import QuestionDatabase
from iterator import Iterator
# Runs of non-word characters are collapsed to a single space before splitting.
_NON_WORD_RUN = re.compile(r'\W+')


def _tokenize(question):
    """Join all sentences of ``question`` and split them into word tokens."""
    # Sentences are joined in the iteration order of ``question.text``.
    text = ' '.join(question.text.values()).strip()
    # do more careful preprocessing here
    return _NON_WORD_RUN.sub(' ', text).split()


def preprocess(all_questions):
    """Build vocabularies and id-encoded train/dev datasets.

    The vocabularies are built from the ``'guesstrain'`` fold only; the
    ``'guesstrain'`` and ``'guessdev'`` folds are then encoded with them.
    Questions from any other fold are ignored.

    Args:
        all_questions: mapping whose values expose ``fold``, ``page`` and a
            ``text`` mapping of sentence strings.

    Returns:
        ``(question_vocab, answer_vocab, dataset)`` where ``dataset`` is a
        ``defaultdict(list)`` with the keys ``'train'`` and ``'dev'`` (each
        present only if that fold had questions) holding
        ``(word_ids, answer_id)`` pairs.
    """
    question_vocab = Vocab()
    answer_vocab = Vocab()
    for question in all_questions.values():
        if question.fold != 'guesstrain':
            continue
        for word in _tokenize(question):
            question_vocab.add(word)
        answer_vocab.add(question.page)
    question_vocab.finish()
    answer_vocab.finish()

    fold_to_split = {'guesstrain': 'train', 'guessdev': 'dev'}
    dataset = defaultdict(list)
    for question in all_questions.values():
        split = fold_to_split.get(question.fold)
        if split is None:
            continue
        words = question_vocab.sent2ids(_tokenize(question))
        answer = answer_vocab.word2id(question.page)
        dataset[split].append((words, answer))
    return question_vocab, answer_vocab, dataset
def main():
    """Load all questions, build the datasets and walk over the train batches.

    Prints dataset and vocabulary sizes, the first training sample and the
    tensor sizes of the first batch; the loop body is a placeholder for the
    actual training step.
    """
    batch_size = 64
    all_questions = QuestionDatabase().all_questions()  # dict
    question_vocab, answer_vocab, dataset = preprocess(all_questions)
    print(len(dataset['train']), len(dataset['dev']))
    print(len(question_vocab), len(answer_vocab))
    print(dataset['train'][0])

    iterators = {
        fold: Iterator(samples, batch_size)
        for fold, samples in dataset.items()
    }

    question_batch, answer_batch = iterators['train'].next_batch()
    print(question_batch.size(), answer_batch.size())

    # Iterator exposes its number of batches through `size`; calling len()
    # on it raises TypeError unless the class defines __len__.
    for _ in range(iterators['train'].size):
        question_batch, answer_batch = iterators['train'].next_batch()
        # training


if __name__ == '__main__':
    main()
import sqlite3
class Question:
    """Plain record describing one question and its sentences.

    The constructor stores every argument under an attribute of the same
    name; ``text`` starts out empty and is filled through ``add_text``.
    """

    def __init__(self, qnum, answer, category, naqt, protobowl,
                 tournaments, page, fold):
        self.qnum, self.answer, self.category = qnum, answer, category
        self.naqt, self.protobowl = naqt, protobowl
        self.tournaments, self.page, self.fold = tournaments, page, fold
        # Maps a sentence key to that sentence's text.
        self.text = dict()

    def add_text(self, sent, text):
        """Store ``text`` under the key ``sent``, replacing any earlier entry."""
        self.text.update({sent: text})
class QuestionDatabase:
    """Thin reader over a SQLite database of questions.

    The database is expected to provide a ``questions`` table (columns
    ``id, page, category, answer, tournament, naqt, protobowl, fold``) and a
    ``text`` table (columns ``question, sent, raw``), as used by the queries
    below.

    The object can be used as a context manager so that the underlying
    connection is closed deterministically.
    """

    # Column list prepended to every command passed to ``query``.
    _SELECT_PREFIX = ('select id, page, category, answer, '
                      'tournament, naqt, protobowl, fold ')

    def __init__(self, location='2017_05_25.db'):
        self._conn = sqlite3.connect(location)

    def close(self):
        """Close the underlying SQLite connection."""
        self._conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow exceptions from the with-body

    def query(self, command, arguments):
        """Run a question query and attach the sentences of each hit.

        Args:
            command: SQL tail starting at ``FROM``; it is appended to the
                fixed column list.
            arguments: parameters bound to the ``?`` placeholders in
                ``command``.

        Returns:
            dict mapping question id to ``Question``.
        """
        questions = {}
        c = self._conn.cursor()
        try:
            c.execute(self._SELECT_PREFIX + command, arguments)
            # The category column is selected but discarded: every Question
            # is created with category=None.
            # NOTE(review): confirm this is intended.
            for qnum, page, _, answer, tournaments, naqt, protobowl, fold in c:
                questions[qnum] = Question(qnum, answer, None, naqt,
                                           protobowl, tournaments, page, fold)

            text_command = ('select sent, raw from text '
                            'where question=? order by sent asc')
            for qnum, question in questions.items():
                c.execute(text_command, (qnum, ))
                for sentence, text in c:
                    question.add_text(sentence, text)
        finally:
            c.close()
        return questions

    def all_questions(self):
        """Return every question whose ``page`` is not the empty string."""
        # Single quotes: in SQL a double-quoted token is an identifier, and
        # SQLite only treats it as a string through a legacy fallback that
        # can be disabled at build time.
        return self.query("FROM questions where page != ''", ())


if __name__ == '__main__':
    db = QuestionDatabase()
    qs = db.all_questions()
from collections import defaultdict
# Reserved ids/tokens: they always occupy the first two vocabulary slots.
PAD_ID = 0
UNK_ID = 1
PAD = "<PAD>"
UNK = "<UNK>"


class Vocab(object):
    """Frequency-based vocabulary with reserved padding/unknown entries.

    Usage: call ``add`` for every token, then ``finish`` to freeze the
    mapping.  Before ``finish`` is called the vocabulary contains only the
    two reserved tokens, so every lookup yields ``UNK_ID``.
    """

    def __init__(self):
        # `int` instead of a lambda keeps the counter (and therefore the
        # whole Vocab) picklable.
        self.word_count = defaultdict(int)
        # Start with a valid minimal mapping so that len() and word2id()
        # work even if finish() has not been called yet.
        self.i2w = [PAD, UNK]
        self.w2i = {PAD: PAD_ID, UNK: UNK_ID}

    def add(self, word):
        """Count one occurrence of ``word``."""
        self.word_count[word] += 1

    def finish(self, size=50000):
        """Freeze the vocabulary to the ``size`` most frequent words.

        Words with equal counts keep the order in which they were first
        added.  Counted words equal to a reserved token are skipped so they
        cannot take over ``PAD_ID``/``UNK_ID``.

        Args:
            size: maximum number of words kept in addition to the two
                reserved tokens.
        """
        counted = [(word, count) for word, count in self.word_count.items()
                   if word not in (PAD, UNK)]
        counted.sort(key=lambda item: item[1], reverse=True)
        self.i2w = [PAD, UNK]
        self.i2w += [word for word, _ in counted[:size]]
        self.w2i = {word: index for index, word in enumerate(self.i2w)}

    def __len__(self):
        """Number of entries, including the two reserved tokens."""
        return len(self.i2w)

    def word2id(self, word):
        """Return the id of ``word``, or ``UNK_ID`` if it is not known."""
        return self.w2i.get(word, UNK_ID)

    def id2word(self, i):
        """Return the token stored at index ``i``."""
        return self.i2w[i]

    def sent2ids(self, sentence):
        """Map every token of ``sentence`` to its id."""
        return [self.word2id(w) for w in sentence]

    def ids2sent(self, ids):
        """Map ids back to tokens, dropping padding ids."""
        return [self.id2word(x) for x in ids if x != PAD_ID]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment