Skip to content

Instantly share code, notes, and snippets.

@ajfisch
Created May 8, 2017 15:55
Show Gist options
  • Save ajfisch/d447b9fd610b1d9868843350145fc6e3 to your computer and use it in GitHub Desktop.
Save ajfisch/d447b9fd610b1d9868843350145fc6e3 to your computer and use it in GitHub Desktop.
Interactive DrQA model with ParlAI
# Run: python path/to/file --pretrained_model path/to/model
#
# Example interaction:
# Context: I was thirsty today. So I went to the market and bought some water.
# Question: What did I buy?
# Reply: some water
import torch
import logging
from parlai.agents.drqa.agents import DocReaderAgent
from parlai.core.params import ParlaiParser
def main(opt):
    """Run an interactive question-answering loop with a DrQA DocReaderAgent.

    Repeatedly prompts on stdin for a context passage and a question, sends
    them to the agent as a single one-turn episode, and prints the agent's
    reply text. The loop ends cleanly on end-of-input (Ctrl-D) or Ctrl-C at
    a prompt.

    Args:
        opt: option dict produced by ``ParlaiParser.parse_args()`` after
            ``DocReaderAgent.add_cmdline_args``; must contain a non-empty
            ``'pretrained_model'`` entry.

    Raises:
        ValueError: if no pretrained model path was supplied.
    """
    # Look the logger up by name rather than relying on a global that only
    # exists when this file is executed as a script; calling main() from an
    # import would otherwise raise NameError.
    logger = logging.getLogger('DrQA')

    # Load document reader (needs a pretrained model). An explicit check
    # survives `python -O` (unlike `assert`) and also rejects the case where
    # the option key is present but left unset/empty.
    if not opt.get('pretrained_model'):
        raise ValueError('--pretrained_model is required for interactive mode')
    doc_reader = DocReaderAgent(opt)

    # Log params; pass the argument lazily so logging does the formatting.
    logger.info('[ Created with options: ] %s',
                ''.join('\n{}\t{}'.format(k, v) for k, v in opt.items()))

    while True:
        try:
            context = input('Context: ')
            question = input('Question: ')
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at a prompt: leave the loop instead of
            # dumping a traceback. The print() ends the dangling prompt line.
            print()
            break
        # Context and question are sent as one newline-separated 'text'
        # field (presumably the agent splits them apart again — confirm
        # against DocReaderAgent.observe). Each exchange is its own episode.
        observation = {'text': '\n'.join([context, question]),
                       'episode_done': True}
        doc_reader.observe(observation)
        reply = doc_reader.act()
        print('Reply: %s' % reply['text'])
if __name__ == '__main__':
    # Build the command-line parser, including DrQA's own flags, and parse.
    cli = ParlaiParser()
    DocReaderAgent.add_cmdline_args(cli)
    opt = cli.parse_args()

    # Logging: one console handler (StreamHandler defaults to stderr) with a
    # timestamped format. `logger` stays a module-level name on purpose.
    logger = logging.getLogger('DrQA')
    logger.setLevel(logging.INFO)
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(
        logging.Formatter('%(asctime)s: %(message)s', '%m/%d/%Y %I:%M:%S %p'))
    logger.addHandler(stderr_handler)

    # Use the GPU only if it was not disabled and CUDA is actually present.
    opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
    if opt['cuda']:
        device_id = opt['gpu']
        logger.info('[ Using CUDA (GPU %d) ]' % device_id)
        torch.cuda.set_device(device_id)

    # Hand off to the interactive loop.
    main(opt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment