Skip to content

Instantly share code, notes, and snippets.

@yassersouri
Forked from SNagappan/README.md
Created December 2, 2015 11:48
Embed
What would you like to do?
bAbI

## Model

This is an implementation of Facebook's baseline GRU/LSTM model on the bAbI dataset (Weston et al., 2015). It includes an interactive demo.

The bAbI dataset contains 20 different question answering tasks.

Model script

The model training script train.py and demo script demo.py are included below.

Instructions

First run the train.py script to get a pickle file of model weights. Use the command line arguments --rlayer_type to choose between LSTMs or GRUs, --save_path to specify the output pickle file location, and -t to specify which bAbI task to run.

python examples/babi/train.py -e 20 --rlayer_type gru --save_path babi.p -t 15

Second, run the demo with the newly created pickle file.

python examples/babi/demo.py -t 15 --rlayer_type gru --model_weights babi.p
Task is en/qa15_basic-deduction

Example from test set:

Story
Wolves are afraid of mice.
Sheep are afraid of mice.
Winona is a sheep.
Mice are afraid of cats.
Cats are afraid of wolves.
Jessica is a mouse.
Emily is a cat.
Gertrude is a wolf.

Question
What is emily afraid of?

Answer
wolf

Please enter a story:

At which point you can play around with your own stories, questions, and answers.

Trained weights

The trained weights file for a GRU network trained on task 3 can be downloaded from AWS using the following link: trained model weights on task 3.

Performance

Task Number | FB LSTM Baseline | Neon QA GRU
--- | --- | ---
QA1 - Single Supporting Fact | 50 | 47.9
QA2 - Two Supporting Facts | 20 | 29.8
QA3 - Three Supporting Facts | 20 | 20.0
QA4 - Two Arg. Relations | 61 | 69.8
QA5 - Three Arg. Relations | 70 | 56.4
QA6 - Yes/No Questions | 48 | 49.1
QA7 - Counting | 49 | 76.5
QA8 - Lists/Sets | 45 | 68.9
QA9 - Simple Negation | 64 | 62.8
QA10 - Indefinite Knowledge | 44 | 45.3
QA11 - Basic Coreference | 72 | 67.6
QA12 - Conjunction | 74 | 63.9
QA13 - Compound Coreference | 94 | 91.9
QA14 - Time Reasoning | 27 | 36.8
QA15 - Basic Deduction | 21 | 51.4
QA16 - Basic Induction | 23 | 50.1
QA17 - Positional Reasoning | 51 | 49.0
QA18 - Size Reasoning | 52 | 90.5
QA19 - Path Finding | 8 | 9.0
QA20 - Agent's Motivations | 91 | 95.6

Citation

https://research.facebook.com/researchers/1543934539189348
Weston, Jason, et al. "Towards AI-complete question answering: a set of prerequisite toy tasks." arXiv preprint arXiv:1502.05698 (2015).
#!/usr/bin/env python
# ----------------------------------------------------------------------------
# Copyright 2015 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
"""
Interactive demo based on Facebook Q&A dataset: bAbI
Reference:
"Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks"
http://arxiv.org/abs/1502.05698
Usage:
use -t to specify which bAbI task to run
python examples/babi/demo.py -t 1 --rlayer_type gru --model_weights babi.p
"""
import numpy as np
from util import create_model, babi_handler
from neon.backends import gen_backend
from neon.data import BABI, QA
from neon.data.text import Text
from neon.util.argparser import NeonArgparser, extract_valid_args
# parse the command line arguments
parser = NeonArgparser(__doc__)
parser.add_argument('-t', '--task', type=int, default=1, choices=xrange(1, 21),
                    help='the task ID to train/test on from bAbI dataset (1-20)')
parser.add_argument('--rlayer_type', default='gru', choices=['gru', 'lstm'],
                    help='type of recurrent layer to use (gru or lstm)')
parser.add_argument('--model_weights',
                    help='pickle file of trained weights')
args = parser.parse_args(gen_be=False)
# The demo cannot run without trained weights; report a usage error up front
# instead of failing later inside load_weights(None).
if args.model_weights is None:
    parser.error('--model_weights is required (pickle file produced by train.py)')

# setup backend; the demo feeds a single story/question pair at a time
be = gen_backend(**extract_valid_args(args, gen_backend))
be.bsz = 1

# load the bAbI dataset
babi = babi_handler(args.data_dir, args.task)
valid_set = QA(*babi.test)

# create model and restore the trained weights
model_inference = create_model(babi.vocab_size, args.rlayer_type)
model_inference.load_weights(args.model_weights)
model_inference.initialize(dataset=valid_set)


def stitch_sentence(words):
    """Join tokens back into readable text, putting each sentence on its own line."""
    return " ".join(words).replace(" ?", "?").replace(" .", ".\n").replace("\n ", "\n")


def vectorize(words, max_len):
    """Tokenize raw text, map it to vocabulary indices and pad it to max_len."""
    return be.array(Text.pad_sentences([babi.words_to_vector(BABI.tokenize(words))], max_len))


# show one example from the test set so the user sees the expected format
ex_story, ex_question, ex_answer = babi.test_parsed[0]
print("\nExample from test set:")
print("\nStory")
print(stitch_sentence(ex_story))
print("Question")
print(stitch_sentence(ex_question))
print("\nAnswer")
print(ex_answer)

while True:
    # ask user for story (terminated by an empty line) and question
    try:
        story_lines = []
        line = raw_input("\nPlease enter a story:\n")
        while line != "":
            story_lines.append(line)
            line = raw_input()
        question = raw_input("Please enter a question:\n")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the session cleanly instead of with a traceback
        print("")
        break
    story = "\n".join(story_lines).strip()

    # convert user input into a suitable network input
    s = vectorize(story, babi.story_maxlen)
    q = vectorize(question, babi.query_maxlen)

    # get prediction probabilities with forward propagation; flatten the
    # single-example output so indexing below yields scalars, not nested arrays
    probs = model_inference.fprop(x=(s, q), inference=True).get().ravel()

    # indices of the k most probable words, most probable first
    k = min(5, probs.size)
    top = np.argpartition(probs, -k)[-k:]
    top = top[np.argsort(probs[top])[::-1]]

    print("\nAnswer:")
    for idx in top:
        print("%s %s" % (babi.index_to_word[int(idx)], float(probs[idx])))
#!/usr/bin/env python
# ----------------------------------------------------------------------------
# Copyright 2015 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
"""
Example that trains on Facebook Q&A dataset: bAbI
Task Number | FB LSTM Baseline | Neon QA GRU
--- | --- | ---
QA1 - Single Supporting Fact | 50 | 47.9
QA2 - Two Supporting Facts | 20 | 29.8
QA3 - Three Supporting Facts | 20 | 20.0
QA4 - Two Arg. Relations | 61 | 69.8
QA5 - Three Arg. Relations | 70 | 56.4
QA6 - Yes/No Questions | 48 | 49.1
QA7 - Counting | 49 | 76.5
QA8 - Lists/Sets | 45 | 68.9
QA9 - Simple Negation | 64 | 62.8
QA10 - Indefinite Knowledge | 44 | 45.3
QA11 - Basic Coreference | 72 | 67.6
QA12 - Conjunction | 74 | 63.9
QA13 - Compound Coreference | 94 | 91.9
QA14 - Time Reasoning | 27 | 36.8
QA15 - Basic Deduction | 21 | 51.4
QA16 - Basic Induction | 23 | 50.1
QA17 - Positional Reasoning | 51 | 49.0
QA18 - Size Reasoning | 52 | 90.5
QA19 - Path Finding | 8 | 9.0
QA20 - Agent's Motivations | 91 | 95.6
Reference:
"Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks"
http://arxiv.org/abs/1502.05698
Usage:
use -t to specify which bAbI task to run
python examples/babi/train.py -e 20 --rlayer_type gru --save_path babi_lstm.p -t 1
"""
from util import create_model, babi_handler
from neon.backends import gen_backend
from neon.data import QA
from neon.layers import GeneralizedCost
from neon.optimizers import Adam
from neon.transforms import Accuracy, CrossEntropyMulti
from neon.callbacks.callbacks import Callbacks
from neon.util.argparser import NeonArgparser, extract_valid_args
# Command-line interface: bAbI task number and recurrent cell flavour, on top
# of the standard neon arguments.
parser = NeonArgparser(__doc__)
parser.add_argument('-t', '--task', type=int, default='1', choices=xrange(1, 21),
                    help='the task ID to train/test on from bAbI dataset (1-20)')
parser.add_argument('--rlayer_type', default='gru', choices=['gru', 'lstm'],
                    help='type of recurrent layer to use (gru or lstm)')
args = parser.parse_args(gen_be=False)

# Weights are written to babi.p unless a path was given; the checkpoint
# callback reuses that same location unless it was configured on its own.
if args.save_path is None:
    args.save_path = 'babi.p'
callback_cfg = args.callback_args
if callback_cfg['save_path'] is None:
    callback_cfg['save_path'] = args.save_path

# NOTE: the batch size is pinned to 32 here, overriding any value passed on
# the command line.
args.batch_size = 32
be = gen_backend(**extract_valid_args(args, gen_backend))

# Train on the task's training split and validate on its test split.
babi = babi_handler(args.data_dir, args.task)
train_set, valid_set = QA(*babi.train), QA(*babi.test)

model = create_model(babi.vocab_size, args.rlayer_type)
callbacks = Callbacks(model, train_set, eval_set=valid_set, **callback_cfg)

optimizer = Adam()
cost_fn = GeneralizedCost(costfunc=CrossEntropyMulti())
model.fit(train_set, optimizer=optimizer, num_epochs=args.epochs,
          cost=cost_fn, callbacks=callbacks)

# Report classification accuracy (as a percentage) on both splits.
for split_name, split_data in (('Train', train_set), ('Test', valid_set)):
    accuracy = model.eval(split_data, metric=Accuracy()) * 100
    print('%s Accuracy = %.1f%%' % (split_name, accuracy))
#!/usr/bin/env python
# ----------------------------------------------------------------------------
# Copyright 2015 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
"""
Utility functions for bAbI example and demo.
"""
from neon.data import BABI
from neon.initializers import GlorotUniform, Uniform, Orthonormal
from neon.layers import Affine, GRU, LookupTable, MergeMultistream, LSTM
from neon.models import Model
from neon.transforms import Logistic, Softmax, Tanh
# Language subset of the bAbI corpus that the tasks are read from.
subset = 'en'

# Names of the 20 bAbI tasks. Entry i - 1 belongs to task i, so the 'qaN_'
# prefix is derived from the position instead of being typed by hand.
# (A generator expression is used so no loop variables leak into the module.)
task_list = list(
    'qa%d_%s' % (number, slug)
    for number, slug in enumerate((
        'single-supporting-fact',
        'two-supporting-facts',
        'three-supporting-facts',
        'two-arg-relations',
        'three-arg-relations',
        'yes-no-questions',
        'counting',
        'lists-sets',
        'simple-negation',
        'indefinite-knowledge',
        'basic-coreference',
        'conjunction',
        'compound-coreference',
        'time-reasoning',
        'basic-deduction',
        'basic-induction',
        'positional-reasoning',
        'size-reasoning',
        'path-finding',
        'agents-motivations',
    ), 1)
)
def babi_handler(data_dir, task_number):
    """
    Build the neon bAbI data handler for a single task.

    Args:
        data_dir (string) : Path to bAbI data directory.
        task_number (int) : The task ID from the bAbI dataset (1-20).

    Returns:
        BABI : Handler for bAbI task.

    Raises:
        ValueError : If task_number is not between 1 and len(task_list).
    """
    # Explicit range check: indexing task_list[task_number - 1] directly would
    # silently map 0 (or a negative number) onto tasks at the end of the list
    # through Python's negative indexing.
    if not 1 <= task_number <= len(task_list):
        raise ValueError('task_number must be between 1 and %d, got %r'
                         % (len(task_list), task_number))
    task = task_list[task_number - 1]
    return BABI(path=data_dir, task=task, subset=subset)
def create_model(vocab_size, rlayer_type, embedding_dim=50, hidden_size=100):
    """
    Create LSTM/GRU model for bAbI dataset.

    The story and the question each go through their own LookupTable and
    recurrent layer. The two encodings are stacked and fed to a softmax
    Affine layer with one output per vocabulary word.

    Args:
        vocab_size (int) : Number of words in the bAbI vocabulary; sets the
            size of both lookup tables and of the output layer.
        rlayer_type (string) : Type of recurrent layer to use (gru or lstm).
        embedding_dim (int, optional) : Word embedding size of the lookup
            tables. Defaults to 50.
        hidden_size (int, optional) : Output size of each recurrent layer.
            Defaults to 100.

    Returns:
        Model : Model of the created network

    Raises:
        ValueError : If rlayer_type is neither 'gru' nor 'lstm'.
    """
    # Reject unknown layer types up front: otherwise any value other than
    # 'gru' would build an LSTM without the LSTM-specific activations below.
    if rlayer_type not in ('gru', 'lstm'):
        raise ValueError("rlayer_type must be 'gru' or 'lstm', got %r" % (rlayer_type,))

    # recurrent layer parameters (default gru)
    rlayer_obj = GRU if rlayer_type == 'gru' else LSTM
    rlayer_params = dict(output_size=hidden_size, reset_cells=True,
                         init=GlorotUniform(), init_inner=Orthonormal(0.5),
                         activation=Tanh(), gate_activation=Logistic())
    # If using lstm, swap the activation functions.
    # NOTE(review): presumably this matches how this neon version's LSTM
    # interprets activation vs gate_activation — confirm against its docs.
    if rlayer_type == 'lstm':
        rlayer_params.update(dict(activation=Logistic(), gate_activation=Tanh()))

    # lookup layer parameters
    lookup_params = dict(vocab_size=vocab_size, embedding_dim=embedding_dim,
                         init=Uniform(-0.05, 0.05))

    # Model construction: separate (non-shared) layers for story and question
    story_path = [LookupTable(**lookup_params), rlayer_obj(**rlayer_params)]
    query_path = [LookupTable(**lookup_params), rlayer_obj(**rlayer_params)]
    layers = [MergeMultistream(layers=[story_path, query_path], merge="stack"),
              Affine(vocab_size, init=GlorotUniform(), activation=Softmax())]
    return Model(layers=layers)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment