Last active
February 18, 2020 13:40
-
-
Save noskill/3ea20a3331b30e02d5799dfa9f82c558 to your computer and use it in GitHub Desktop.
GQA to logic
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import re
import sys
from collections import defaultdict
# Patterns used to parse the ``argument`` string of a program step.
# Raw strings keep the regex escapes (\w, \s, \d, ...) out of Python's own
# string-escape processing.

# "<name>(<digits>)": a name (word characters and whitespace) followed by a
# numeric id in parentheses.
two_args_re = re.compile(r'^([\w\s]+)\((\d+)\)')
# "<name>(-)": same shape, but the id is the literal placeholder "-".
two_args_re_no = re.compile(r'^([\w\s]+)\((-)\)')
# "<name> (<id>,<id>,...)": one word followed by two or more numeric ids.
many_objects = re.compile(r'^(\w+)\s\((\d+(,\d+)+)\)')
# A single bare word; the second (empty) group keeps the two-group shape
# shared with the patterns above.
one_word = re.compile(r'^(\w+)($)')
# "s (<id>)" or "o (<id>)". The original class ``[s|o]`` also accepted a
# literal "|", which is not a valid role marker.
s_re = re.compile(r'([so])\s\((.*)\)')

# Pool of variable names handed out to objects, in order of first appearance.
varnames = ['$X', '$Y', '$Z', '$E', '$R']

# Maps a relation phrase as it appears in a step argument to the predicate
# name used in the generated expression.
relations = {
    'to the right of': 'right_of',
    'on': 'on',
    'to the left of': 'left_of',
    'by': 'by',
    'in': 'in',
    'carrying': 'carrying',
    'behind': 'behind',
    'wearing': 'wearing',
    'near': 'near',
    'in front of': 'front_of',
    'with': 'with',
    'full of': 'full_of',
    'on top of': 'on_top_of',
    'above': 'above',
    'sitting on': 'sits_on',
    'covered in': 'covered_in',
    'of': 'of',
    'feeding': 'feeding',
    'holding': 'holding',
    'drinking from': 'drinks_from',
    'lying in': 'lays_in',
    'riding': 'rides',
    'using': 'uses',
    'lying on': 'lays_on',
    'below': 'below',
    'sitting on top of': 'sits_on_top',
    'hanging on': 'hangs_on',
    'leaning against': 'leans_against',
    'kicking': 'kicks',
    'next to': 'next_to',
    'eating': 'eating',
    'surrounding': 'surrounds',
}
def make_filter(name, var):
    """Return the predicate text ``name(var)`` built from two strings."""
    pieces = [name, '(', var, ')']
    return ''.join(pieces)
def build_conjuntion(result, dep):
    """Append ``dep`` and all of its transitive dependencies to ``result``.

    Nodes are appended in pre-order (a node first, then each of its
    ``dependencies`` from left to right). A node reachable through several
    paths is appended once per path. ``result`` is modified in place.
    """
    pending = [dep]
    while pending:
        node = pending.pop()
        result.append(node)
        # Push children reversed so the leftmost child is popped first.
        pending.extend(reversed(node.dependencies))
class Node:
    """Base class for one step of a converted program.

    Attributes:
        dependencies: list of nodes this step depends on.
        variables: variable names associated with this step.
    """

    def __init__(self, dependencies, variables):
        # Some subclasses default ``dependencies`` to None; store an empty
        # list instead so every node can be iterated during traversal.
        self.dependencies = [] if dependencies is None else dependencies
        self.variables = variables

    def __str__(self):
        return 'Node'

    def build_expression(self):
        """Return this node followed by all transitive dependencies.

        The result is a comma-separated string: this node's text first,
        then every node reachable through ``dependencies`` in pre-order.
        A node without dependencies yields just its own text (no trailing
        comma).
        """
        terms = [str(self)]
        for dep in self.dependencies:
            subtree = []
            build_conjuntion(subtree, dep)
            terms.extend(str(node) for node in subtree)
        return ','.join(terms)
class Filter(Node):
    """Node rendered as ``<filter_type>(<name>, <first variable>)``."""

    def __init__(self, filter_type, name, variables, dependencies=None):
        super().__init__(dependencies, variables)
        self.filter_type = filter_type
        self.name = name

    def __str__(self):
        first_var = self.variables[0]
        call_args = '({0}, {1})'.format(self.name, first_var)
        return self.filter_type + call_args
class Exists(Node):
    """Node rendered as ``exists(<first variable>)``.

    Requires exactly one dependency.
    """

    def __init__(self, variables, dependencies=None):
        super().__init__(dependencies, variables)
        assert len(dependencies) == 1

    def __str__(self):
        first_var = self.variables[0]
        return 'exists({0})'.format(first_var)
class Disjunction(Node):
    """Node joining exactly two alternative branches."""

    def __str__(self):
        inner = ','.join(map(str, self.dependencies))
        return 'Or(' + inner + ')'

    def build_expression(self):
        """Return ``(<branch 1>);(<branch 2>)``.

        Each branch is the comma-separated pre-order listing of one
        dependency and everything it depends on.
        """
        flattened = []
        for branch_root in self.dependencies:
            nodes = []
            build_conjuntion(nodes, branch_root)
            flattened.append(nodes)
        branches = [','.join(map(str, nodes)) for nodes in flattened]
        assert(len(branches) == 2)
        return '({0});({1})'.format(branches[0], branches[1])
class Conjunction(Node):
    """Node that combines its dependencies with commas."""

    def __str__(self):
        return ','.join(map(str, self.dependencies))

    def build_expression(self):
        """Return the comma-separated pre-order listing of all dependencies.

        Unlike the base class, the conjunction contributes no text of its
        own; only the dependency subtrees are listed.
        """
        collected = []
        for dep in self.dependencies:
            subtree = []
            build_conjuntion(subtree, dep)
            collected += subtree
        return ','.join(str(node) for node in collected)
class Relation(Node):
    """Node rendered as ``<relation>(<arg 1>, <arg 2>)``."""

    def __init__(self, rel_name, args, dependencies, variables):
        super().__init__(dependencies, variables=variables)
        self.args = args
        self.relation = rel_name

    def __str__(self):
        rendered = [str(arg) for arg in self.args]
        return self.relation + '({0}, {1})'.format(*rendered)
class Verify(Node):
    """Node rendered as ``verify_<type>(<arg>, <first variable>)``.

    Requires exactly one dependency.
    """

    def __init__(self, verify_type, verify_arg, dependencies, variables):
        super().__init__(dependencies=dependencies, variables=variables)
        self.verify_type = verify_type
        self.verify_arg = verify_arg
        assert len(dependencies) == 1

    def __str__(self):
        template = 'verify_{0}({1}, {2})'
        return template.format(self.verify_type, self.verify_arg, self.variables[0])

    def build_expression(self):
        """Return this node followed by its dependency subtree, comma-joined."""
        subtree = []
        build_conjuntion(subtree, self.dependencies[0])
        tail = ','.join(str(node) for node in subtree)
        return str(self) + ',' + tail
class Query(Node):
    """Node rendered as ``query(<arg>, <first variable>)``.

    Requires exactly one dependency.
    """

    def __init__(self, arg, dependencies, variables):
        super().__init__(dependencies=dependencies, variables=variables)
        self.arg = arg
        assert len(dependencies) == 1

    def __str__(self):
        first_var = self.variables[0]
        return 'query({0}, {1})'.format(self.arg, first_var)
class Difference(Node):
    """Node rendered as ``different(<arg>, <var>[, <var>])``.

    Subclasses change the predicate name by overriding ``_name``.
    """

    _name = 'different'

    def __init__(self, arg, dependencies, variables):
        super().__init__(dependencies=dependencies, variables=variables)
        self.arg = arg

    def __str__(self):
        # Two variables: compare two objects. One variable: a single
        # (list) variable is passed on its own.
        if len(self.variables) == 2:
            first, second = self.variables
            return self._name + '({0}, {1}, {2})'.format(self.arg, first, second)
        assert len(self.variables) == 1
        return self._name + '({0}, {1})'.format(self.arg, self.variables[0])
class Same(Difference):
    """Variant of ``Difference`` rendered under the predicate name ``same``."""

    _name = 'same'
class Common(Node):
    """Node rendered as ``query_common(<var 1>, <var 2>)``.

    Requires exactly two variables.
    """

    def __init__(self, dependencies, variables):
        super().__init__(dependencies=dependencies, variables=variables)
        assert len(variables) == 2

    def __str__(self):
        first, second = self.variables
        return 'query_common({0}, {1})'.format(first, second)
def build_relate(argument, dependencies, deps, variables, no_obj):
    """Build a ``Relation`` node from a ``relate``-style argument.

    Args:
        argument: string of the form ``"<name>,<relation>,<role> (<id>)"``
            where ``<role>`` is ``s`` or ``o``; ``<name>`` may be ``_``.
        dependencies: dependency indices of the step (exactly one expected).
        deps: the already-built nodes for those indices.
        variables: mapping from object id to variable name; updated in place.
        no_obj: one-element list holding the counter for id-less objects.

    Returns:
        A ``Relation`` whose predicate name is looked up in ``relations``.
        When ``<name>`` is not ``_`` an ``object`` filter on the new
        variable is inserted between the relation and ``deps``.

    Raises:
        ValueError: if the third field does not match the role pattern.
        KeyError: if the relation phrase is not present in ``relations``.
    """
    args = argument.split(',')
    assert len(args) == 3
    target = s_re.match(args[2])
    if target is None:
        raise ValueError(
            'malformed relation target {0!r} in argument {1!r}'.format(args[2], argument))
    rel_type, obj_id = target.groups()
    var_name = get_var_name(no_obj, obj_id, variables)
    assert len(dependencies) == 1
    dep_vars = deps[0].variables
    if args[0] != '_':
        depend = [Filter(filter_type='object', name=args[0],
                         variables=[var_name], dependencies=deps)]
    else:
        depend = deps
    dep_var = dep_vars[0]
    # Role 's' puts the newly named object first in the relation;
    # any other role puts the dependency's variable first.
    if rel_type == 's':
        rel_args = [var_name, dep_var]
    else:
        rel_args = [dep_var, var_name]
    return Relation(relations[args[1]], rel_args,
                    dependencies=depend, variables=[var_name, dep_var])
def convert(items, ops, variables, no_obj):
    """Translate a list of program steps into one expression string.

    Each step in ``items`` is a mapping with the keys ``'operation'``,
    ``'dependencies'`` (indices into ``ops``) and ``'argument'``. Steps are
    processed in order; each one appends exactly one node to ``ops``.

    Args:
        items: sequence of step mappings.
        ops: list of nodes built so far; extended in place.
        variables: mapping from object id to variable name; updated in place.
        no_obj: one-element list holding the counter for id-less objects.

    Returns:
        ``build_expression()`` of the last node in ``ops``.

    Raises:
        ValueError: for an unsupported operation or a ``select`` argument
            that is empty or matches none of the known patterns.
    """
    for item in items:
        operation = item['operation'].strip().split()
        dependencies = item['dependencies']
        deps = [ops[i] for i in dependencies]
        argument = item['argument']
        op_name = operation[0]

        if op_name == 'select':
            if not argument:
                # Appending nothing here would shift every later
                # dependency index, so fail loudly instead.
                raise ValueError('select step without an argument: {0!r}'.format(item))
            match = None
            for pattern in (two_args_re, two_args_re_no, one_word):
                match = pattern.match(argument)
                if match:
                    break
            if match:
                name, obj_id = match.groups()
                var_name = get_var_name(no_obj, obj_id, variables)
                ops.append(Filter('object', name.strip(), [var_name], deps))
            else:
                match = many_objects.match(argument)
                if match is None:
                    raise ValueError('cannot parse select argument: {0!r}'.format(argument))
                name, obj_ids = match.groups()[:2]
                # Several ids share one variable, rendered as a list.
                var_name = get_var_name(no_obj, obj_ids, variables)
                ops.append(Filter('object', name.strip(), ['[{0}]'.format(var_name)], deps))
        elif op_name == 'filter':
            if len(operation) == 1:
                operation.append('is')
            assert len(operation) == 2
            assert len(dependencies) == 1
            ops.append(Filter(filter_type=operation[1], name=argument,
                              variables=deps[0].variables, dependencies=deps))
        elif op_name == 'exist':
            assert len(dependencies) == 1
            ops.append(Exists(dependencies=deps, variables=deps[0].variables))
        elif op_name == 'or':
            ops.append(Disjunction(dependencies=deps, variables=extract_deps(deps)))
        elif op_name == 'relate':
            ops.append(build_relate(argument, dependencies, deps, variables, no_obj))
        elif op_name == 'verify':
            if len(operation) == 1:
                operation.append('is')
            if operation[1] == 'rel':
                ops.append(build_relate(argument, dependencies, deps, variables, no_obj))
            else:
                ops.append(Verify(verify_type=operation[1], verify_arg=argument,
                                  variables=deps[0].variables, dependencies=deps))
        elif op_name == 'and':
            ops.append(Conjunction(dependencies=deps, variables=extract_deps(deps)))
        elif op_name == 'query':
            ops.append(Query(argument, deps, deps[0].variables))
        elif op_name == 'choose':
            alternatives = []
            if len(operation) == 1:
                operation.append('is')
            assert len(deps) == 1
            if operation[1] == 'rel':
                # The middle field lists candidate relations separated by
                # '|'; build one relation node per candidate.
                left, middle, right = argument.split(',')
                for rel_arg in middle.split('|'):
                    new_argument = '{0},{1},{2}'.format(left, rel_arg, right)
                    alternatives.append(
                        build_relate(new_argument, dependencies, deps, variables, no_obj))
            else:
                for arg in argument.split('|'):
                    alternatives.append(
                        Verify(verify_type=operation[1], verify_arg=arg,
                               variables=deps[0].variables, dependencies=deps))
            ops.append(Disjunction(dependencies=alternatives,
                                   variables=extract_deps(deps)))
        elif op_name == 'different':
            # same_difference_params may append ``argument`` to
            # ``operation``, so operation[1] is read only afterwards.
            node_vars = same_difference_params(argument, deps, operation)
            ops.append(Difference(operation[1], dependencies=deps, variables=node_vars))
        elif op_name == 'same':
            node_vars = same_difference_params(argument, deps, operation)
            ops.append(Same(operation[1], dependencies=deps, variables=node_vars))
        elif op_name == 'common':
            node_vars = deps[0].variables[0], deps[1].variables[0]
            ops.append(Common(dependencies=deps, variables=node_vars))
        else:
            raise ValueError('unsupported operation: {0!r}'.format(item['operation']))

    return ops[-1].build_expression()
def same_difference_params(argument, deps, operation):
    """Pick the variables for a ``same``/``different`` step.

    Args:
        argument: the step's argument string.
        deps: nodes the step depends on; one or two are supported.
        operation: tokenised operation; when it has a single token,
            ``argument`` is appended to it in place.

    Returns:
        With two dependencies, a tuple of each one's first variable (neither
        may be a ``[...]`` list variable). With one dependency, a one-element
        list holding its first variable, which must be a list variable.

    Raises:
        ValueError: if ``deps`` has neither one nor two entries.
    """
    if len(deps) == 2:
        selected = deps[0].variables[0], deps[1].variables[0]
        assert all(not x.startswith('[') for x in selected)
    elif len(deps) == 1:
        selected = [deps[0].variables[0]]
        assert selected[0].startswith('[')
    else:
        raise ValueError(
            'expected 1 or 2 dependencies, got {0}'.format(len(deps)))
    if len(operation) == 1:
        operation.append(argument)
    return selected
def extract_deps(deps):
    """Return the distinct variables of ``deps`` in first-seen order."""
    unique = []
    every_variable = (name for node in deps for name in node.variables)
    for name in every_variable:
        if name in unique:
            continue
        unique.append(name)
    return unique
def get_var_name(no_obj, obj_id, variables):
    """Return the variable name assigned to ``obj_id``, allocating if new.

    Args:
        no_obj: one-element list used as a mutable counter for objects
            without a real id.
        obj_id: object id string; an empty string or ``'-'`` means "no id".
        variables: mapping from object id to variable name; updated in place.

    Returns:
        The variable name for the object. The first ``len(varnames)``
        distinct objects receive names from ``varnames``; later ones get a
        generated ``$V<n>`` name instead of raising ``IndexError``.
    """
    if (not obj_id) or (obj_id == '-'):
        # Id-less objects get a synthetic key so each occurrence is
        # treated as a distinct object.
        obj_id = obj_id + str(no_obj[0])
        no_obj[0] += 1
    if obj_id not in variables:
        index = len(variables)
        if index < len(varnames):
            variables[obj_id] = varnames[index]
        else:
            variables[obj_id] = '$V{0}'.format(index)
    return variables[obj_id]
def main():
    """Convert the questions of a JSON file and print the results.

    Expects exactly one command-line argument: the path to a JSON file
    mapping keys to entries that have ``'question'`` and ``'semantic'``
    fields. For each entry the question and its converted expression are
    printed, prefixed with the entry's index. Processing stops after the
    entry with index 10001.
    """
    if len(sys.argv) != 2:
        print('give path to TRAIN json file as an argument')
        # Without a path there is nothing to do; exit instead of failing
        # on sys.argv[1] below.
        sys.exit(1)
    path = sys.argv[1]
    with open(path) as handle:
        data = json.load(handle)
    for i, key in enumerate(data):
        sem = data[key]['semantic']
        # Fresh state per question: variable table, id-less object counter
        # and the list of nodes built so far.
        variables = dict()
        no_obj_count = [0]
        ops = []
        print('{0}: '.format(i), data[key]['question'])
        res = convert(sem, ops, variables, no_obj_count)
        print('{0}: '.format(i), res)
        if i > 10000:
            break


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment