Skip to content

Instantly share code, notes, and snippets.

@noskill
Last active February 18, 2020 13:40
Show Gist options
  • Save noskill/3ea20a3331b30e02d5799dfa9f82c558 to your computer and use it in GitHub Desktop.
Save noskill/3ea20a3331b30e02d5799dfa9f82c558 to your computer and use it in GitHub Desktop.
GQA to logic
import re
from collections import defaultdict
import json
# Patterns used to parse the ``argument`` string of a semantic step.
# They are written as raw strings so that ``\w``, ``\s``, ``\(`` and friends
# reach the regex engine untouched instead of being (invalid) Python string
# escapes; the compiled pattern text is unchanged.

# "<name> (<digits>)" -> groups: (name, still carrying trailing whitespace; digits)
two_args_re = re.compile(r'^([\w\s]+)\((\d+)\)')
# "<name> (-)" -> groups: (name, '-'); the id is a literal dash
two_args_re_no = re.compile(r'^([\w\s]+)\((-)\)')
# "<word> (<d>,<d>[,<d>...])" -> groups: (word, full comma list, last ",<d>")
many_objects = re.compile(r'^(\w+)\s\((\d+(\,\d+)+)\)')
# A single bare word -> groups: (word, ''); the empty second group keeps the
# two-group shape of the patterns above.
one_word = re.compile(r'^(\w+)($)')
# "<s|o> (<anything>)" -> groups: (role letter, text inside the parentheses)
# NOTE(review): ``[s|o]`` is a character class, so it also accepts a literal
# '|'; kept as-is to preserve matching behaviour -- confirm whether ``[so]``
# was intended.
s_re = re.compile(r'([s|o])\s\((.*)\)')
# Pool of variable names handed out, in order, to newly seen objects.
varnames = ['$X', '$Y', '$Z', '$E', '$R']
# Maps the relation phrase found in a step argument to the predicate name
# emitted in the generated expression.  Looked up by exact phrase.
relations = {
    'to the right of':   'right_of',
    'on':                'on',
    'to the left of':    'left_of',
    'by':                'by',
    'in':                'in',
    'carrying':          'carrying',
    'behind':            'behind',
    'wearing':           'wearing',
    'near':              'near',
    'in front of':       'front_of',
    'with':              'with',
    'full of':           'full_of',
    'on top of':         'on_top_of',
    'above':             'above',
    'sitting on':        'sits_on',
    'covered in':        'covered_in',
    'of':                'of',
    'feeding':           'feeding',
    'holding':           'holding',
    'drinking from':     'drinks_from',
    'lying in':          'lays_in',
    'riding':            'rides',
    'using':             'uses',
    'lying on':          'lays_on',
    'below':             'below',
    'sitting on top of': 'sits_on_top',
    'hanging on':        'hangs_on',
    'leaning against':   'leans_against',
    'kicking':           'kicks',
    'next to':           'next_to',
    'eating':            'eating',
    'surrounding':       'surrounds',
}
def make_filter(name, var):
    """Return the predicate text ``name(var)`` built from two strings."""
    pieces = (name, '(', var, ')')
    return ''.join(pieces)
def build_conjuntion(result, dep):
    """Append ``dep`` and all of its transitive dependencies to ``result``.

    Nodes are appended in depth-first pre-order: a node first, then the
    subtree of each entry of its ``dependencies`` attribute, left to right.
    ``result`` is modified in place; nothing is returned.
    """
    pending = [dep]
    while pending:
        node = pending.pop()
        result.append(node)
        # Push children reversed so the left-most child is popped first.
        pending.extend(list(node.dependencies)[::-1])
class Node:
    """Base class for one step of the generated logic expression.

    Attributes:
        dependencies: list of nodes this node is built on.
        variables: sequence of variable names this node refers to.
    """

    def __init__(self, dependencies, variables):
        # Subclasses default ``dependencies`` to None; store an empty list in
        # that case so the node can always be iterated when rendering.
        self.dependencies = [] if dependencies is None else dependencies
        self.variables = variables

    def __str__(self):
        return 'Node'

    def build_expression(self):
        """Render this node followed by all transitive dependencies.

        Returns:
            A comma-separated string: ``str(self)`` first, then every node
            reachable through ``dependencies`` in depth-first pre-order.
            A node without dependencies renders as just ``str(self)``
            (no dangling trailing comma).
        """
        conj = []
        for dependency in self.dependencies:
            build_conjuntion(conj, dependency)
        return ','.join([str(self)] + [str(x) for x in conj])
class Filter(Node):
    """Predicate of the form ``<filter_type>(<name>, <variable>)``.

    Only the first entry of ``variables`` is used when rendering.
    """

    def __init__(self, filter_type, name, variables, dependencies=None):
        super().__init__(dependencies, variables)
        self.filter_type = filter_type
        self.name = name

    def __str__(self):
        subject = self.variables[0]
        return '{0}({1}, {2})'.format(self.filter_type, self.name, subject)
class Exists(Node):
    """Existence check rendered as ``exists(<variable>)``.

    Requires exactly one dependency; renders the first variable only.
    """

    def __init__(self, variables, dependencies=None):
        super().__init__(dependencies, variables)
        assert len(dependencies) == 1

    def __str__(self):
        return 'exists(%s)' % (self.variables[0],)
class Disjunction(Node):
    """Two alternative branches, rendered as ``(<left>);(<right>)``."""

    def __str__(self):
        rendered = [str(dep) for dep in self.dependencies]
        return 'Or(' + ','.join(rendered) + ')'

    def build_expression(self):
        """Render each dependency's full chain and join the two with ``;``.

        Each branch is the comma-separated chain of the dependency and its
        transitive dependencies.  Exactly two branches are required.
        """
        branches = []
        for dep in self.dependencies:
            chain = []
            build_conjuntion(chain, dep)
            branches.append(chain)
        texts = [','.join(map(str, chain)) for chain in branches]
        assert len(texts) == 2
        left, right = texts
        return '({0});({1})'.format(left, right)
class Conjunction(Node):
    """Comma-joined combination of its dependencies."""

    def __str__(self):
        return ','.join([str(dep) for dep in self.dependencies])

    def build_expression(self):
        """Render every dependency chain, in order, as one comma-separated string.

        Unlike the base class, the conjunction contributes no term of its own.
        """
        terms = []
        for dep in self.dependencies:
            build_conjuntion(terms, dep)
        return ','.join(str(term) for term in terms)
class Relation(Node):
    """Binary predicate rendered as ``<relation>(<arg0>, <arg1>)``."""

    def __init__(self, rel_name, args, dependencies, variables):
        super().__init__(dependencies, variables=variables)
        self.args = args
        self.relation = rel_name

    def __str__(self):
        rendered_args = [str(arg) for arg in self.args]
        return self.relation + '({0}, {1})'.format(*rendered_args)
class Verify(Node):
    """Check rendered as ``verify_<type>(<arg>, <variable>)``.

    Requires exactly one dependency; renders the first variable only.
    """

    def __init__(self, verify_type, verify_arg, dependencies, variables):
        super().__init__(dependencies=dependencies, variables=variables)
        self.verify_arg = verify_arg
        self.verify_type = verify_type
        assert len(dependencies) == 1

    def __str__(self):
        template = 'verify_{0}({1}, {2})'
        return template.format(self.verify_type, self.verify_arg, self.variables[0])

    def build_expression(self):
        """Render this check followed by the chain of its single dependency."""
        chain = []
        build_conjuntion(chain, self.dependencies[0])
        # ``chain`` always holds at least the dependency itself.
        return ','.join([str(self)] + [str(node) for node in chain])
class Query(Node):
    """Attribute lookup rendered as ``query(<arg>, <variable>)``.

    Requires exactly one dependency; renders the first variable only.
    """

    def __init__(self, arg, dependencies, variables):
        super().__init__(dependencies=dependencies, variables=variables)
        self.arg = arg
        assert len(dependencies) == 1

    def __str__(self):
        subject = self.variables[0]
        return 'query({0}, {1})'.format(self.arg, subject)
class Difference(Node):
    """Comparison rendered as ``different(<arg>, <var>[, <var>])``.

    With two variables both are listed; otherwise exactly one variable is
    expected.  Subclasses change the predicate name through ``_name``.
    """

    _name = 'different'

    def __init__(self, arg, dependencies, variables):
        super().__init__(dependencies=dependencies, variables=variables)
        self.arg = arg

    def __str__(self):
        count = len(self.variables)
        if count != 2:
            assert count == 1
            return self._name + '({0}, {1})'.format(self.arg, self.variables[0])
        first, second = self.variables
        return self._name + '({0}, {1}, {2})'.format(self.arg, first, second)


class Same(Difference):
    """Same rendering as ``Difference`` but with the predicate name ``same``."""

    _name = 'same'
class Common(Node):
    """Rendered as ``query_common(<var0>, <var1>)``; needs exactly two variables."""

    def __init__(self, dependencies, variables):
        super().__init__(dependencies=dependencies, variables=variables)
        assert len(variables) == 2

    def __str__(self):
        first, second = self.variables[0], self.variables[1]
        return 'query_common({0}, {1})'.format(first, second)
def build_relate(argument, dependencies, deps, variables, no_obj):
    """Build a ``Relation`` node from a ``"<name>,<relation>,<s|o> (<id>)"`` argument.

    Args:
        argument: comma-separated triple; the third part is parsed with
            ``s_re`` into a role letter and an object id.
        dependencies: raw dependency indices of the step (must have length 1).
        deps: the already-built nodes for those indices.
        variables: object-id -> variable-name mapping, updated in place.
        no_obj: one-element counter list used for objects without an id.

    Returns:
        A ``Relation`` whose variables are ``[new_var, dependency_var]``.
        For role ``'s'`` the new variable is the first relation argument,
        otherwise it is the second.  Unless the name part is ``'_'``, an
        ``object`` filter on the new variable is inserted between the
        relation and its dependency.
    """
    parts = argument.split(',')
    assert len(parts) == 3
    obj_name, rel_phrase, target = parts
    role, obj_id = s_re.match(target).groups()
    new_var = get_var_name(no_obj, obj_id, variables)
    assert len(dependencies) == 1
    dep_variables = deps[0].variables
    if obj_name == '_':
        depend = deps
    else:
        depend = [Filter(filter_type='object', name=obj_name,
                         variables=[new_var], dependencies=deps)]
    rel_vars = [new_var, dep_variables[0]]
    # Role 's' keeps (new, dependency) order; anything else swaps them.
    rel_args = list(rel_vars) if role == 's' else rel_vars[::-1]
    return Relation(relations[rel_phrase], rel_args,
                    dependencies=depend, variables=rel_vars)
def _build_select(argument, deps, variables, no_obj):
    """Build the ``object`` filter for a ``select`` step from its argument text."""
    for pattern in (two_args_re, two_args_re_no, one_word):
        match = pattern.match(argument)
        if match:
            name, obj_id = match.groups()
            var_name = get_var_name(no_obj, obj_id, variables)
            return Filter('object', name.strip(), [var_name], deps)
    match = many_objects.match(argument)
    if match is None:
        raise ValueError(
            "cannot parse 'select' argument: {0!r}".format(argument))
    name, obj_ids = match.groups()[:2]
    # Several ids select a list of objects: the variable is wrapped in [].
    var_name = get_var_name(no_obj, obj_ids, variables)
    return Filter('object', name.strip(), ['[{0}]'.format(var_name)], deps)


def convert(items, ops, variables, no_obj):
    """Translate a sequence of semantic steps into one expression string.

    Args:
        items: sequence of dicts with the keys ``'operation'``,
            ``'dependencies'`` (indices into ``ops``) and ``'argument'``.
        ops: list of already-built nodes; one node is appended per item.
        variables: object-id -> variable-name mapping, updated in place.
        no_obj: one-element counter list used for objects without an id.

    Returns:
        ``build_expression()`` of the last node in ``ops``.

    Raises:
        ValueError: for an unknown operation, a ``select`` step without an
            argument, or a ``select`` argument no pattern can parse.
            (Previously these cases dropped into ``pdb`` or failed on
            ``None``.)
    """
    for item in items:
        operation = item['operation'].strip().split()
        dependencies = item['dependencies']
        deps = [ops[i] for i in dependencies]
        argument = item['argument']
        kind = operation[0]
        if kind == 'select':
            if not argument:
                raise ValueError(
                    "'select' step without an argument: {0!r}".format(item))
            ops.append(_build_select(argument, deps, variables, no_obj))
        elif kind == 'filter':
            if len(operation) == 1:
                # A bare "filter" is a plain "is" filter.
                operation.append('is')
            assert len(operation) == 2
            assert len(dependencies) == 1
            ops.append(Filter(filter_type=operation[1], name=argument,
                              variables=deps[0].variables, dependencies=deps))
        elif kind == 'exist':
            assert len(dependencies) == 1
            ops.append(Exists(dependencies=deps, variables=deps[0].variables))
        elif kind == 'or':
            ops.append(Disjunction(dependencies=deps,
                                   variables=extract_deps(deps)))
        elif kind == 'relate':
            ops.append(build_relate(argument, dependencies, deps,
                                    variables, no_obj))
        elif kind == 'verify':
            if len(operation) == 1:
                operation.append('is')
            if operation[1] == 'rel':
                ops.append(build_relate(argument, dependencies, deps,
                                        variables, no_obj))
            else:
                ops.append(Verify(verify_type=operation[1],
                                  verify_arg=argument,
                                  variables=deps[0].variables,
                                  dependencies=deps))
        elif kind == 'and':
            ops.append(Conjunction(dependencies=deps,
                                   variables=extract_deps(deps)))
        elif kind == 'query':
            ops.append(Query(argument, deps, deps[0].variables))
        elif kind == 'choose':
            alternatives = []
            if len(operation) == 1:
                operation.append('is')
            assert len(deps) == 1
            if operation[1] == 'rel':
                # "left,relA|relB,right" expands to one relation per choice.
                left, middle, right = argument.split(',')
                for rel_arg in middle.split('|'):
                    new_argument = '{0},{1},{2}'.format(left, rel_arg, right)
                    alternatives.append(build_relate(
                        new_argument, dependencies, deps, variables, no_obj))
            else:
                for arg in argument.split('|'):
                    alternatives.append(Verify(verify_type=operation[1],
                                               verify_arg=arg,
                                               variables=deps[0].variables,
                                               dependencies=deps))
            ops.append(Disjunction(dependencies=alternatives,
                                   variables=extract_deps(deps)))
        elif kind == 'different':
            # May append ``argument`` to ``operation`` so operation[1] exists.
            dep_vars = same_difference_params(argument, deps, operation)
            ops.append(Difference(operation[1], dependencies=deps,
                                  variables=dep_vars))
        elif kind == 'same':
            dep_vars = same_difference_params(argument, deps, operation)
            ops.append(Same(operation[1], dependencies=deps,
                            variables=dep_vars))
        elif kind == 'common':
            dep_vars = deps[0].variables[0], deps[1].variables[0]
            ops.append(Common(dependencies=deps, variables=dep_vars))
        else:
            raise ValueError(
                'unsupported operation: {0!r}'.format(item['operation']))
    return ops[-1].build_expression()
def same_difference_params(argument, deps, operation):
    """Collect the variables compared by a ``same`` / ``different`` step.

    Args:
        argument: the step argument; appended to ``operation`` when the
            operation consists of a single word, so that ``operation[1]``
            is available to the caller.
        deps: one or two dependency nodes, each with a ``variables`` sequence.
        operation: the split operation words; may be extended in place.

    Returns:
        For two dependencies, a tuple with the first variable of each (neither
        may be a ``[...]`` list variable).  For one dependency, a one-element
        list holding its first variable, which must be a ``[...]`` list variable.

    Raises:
        ValueError: if ``deps`` has neither one nor two entries (previously
            this surfaced as an ``UnboundLocalError``).
    """
    if len(deps) == 2:
        result = deps[0].variables[0], deps[1].variables[0]
        assert all(not x.startswith('[') for x in result)
    elif len(deps) == 1:
        result = [deps[0].variables[0]]
        assert result[0].startswith('[')
    else:
        raise ValueError(
            'expected 1 or 2 dependencies, got {0}'.format(len(deps)))
    if len(operation) == 1:
        operation.append(argument)
    return result
def extract_deps(deps):
    """Return the variables of all ``deps`` in first-seen order, without repeats."""
    ordered = []
    for var in (v for dep in deps for v in dep.variables):
        if var in ordered:
            continue
        ordered.append(var)
    return ordered
def get_var_name(no_obj, obj_id, variables):
    """Return the variable name bound to ``obj_id``, allocating one if needed.

    Args:
        no_obj: one-element counter list.  When ``obj_id`` is empty or ``'-'``
            the current counter value is appended to the id (making it
            unique) and the counter is incremented.
        obj_id: object id string taken from a step argument.
        variables: id -> variable-name mapping, updated in place; a new id
            receives ``varnames[len(variables)]``.
    """
    if not (obj_id and obj_id != '-'):
        suffix = str(no_obj[0])
        obj_id = obj_id + suffix
        no_obj[0] += 1
    if obj_id in variables:
        return variables[obj_id]
    fresh = varnames[len(variables)]
    variables[obj_id] = fresh
    return fresh
def main():
    """Convert the questions of a JSON file given on the command line.

    Expects exactly one command-line argument: the path to a JSON file that
    maps keys to entries with a ``'question'`` and a ``'semantic'`` field.
    For each entry the question and the converted expression are printed,
    both prefixed with the entry index.

    Exits with status 1 after printing a usage message when the argument
    count is wrong (previously execution continued and failed with an
    ``IndexError``).
    """
    import sys
    if len(sys.argv) != 2:
        print('give path to TRAIN json file as an argument')
        sys.exit(1)
    path = sys.argv[1]
    # Context manager so the file handle is closed once parsing is done.
    with open(path) as handle:
        data = json.load(handle)
    for i, entry in enumerate(data.values()):
        sem = entry['semantic']
        # Fresh per-question state: id -> variable map, anonymous-object
        # counter, and the list of nodes built so far.
        variables = dict()
        no_obj_count = [0]
        ops = []
        print('{0}: '.format(i), entry['question'])
        res = convert(sem, ops, variables, no_obj_count)
        print('{0}: '.format(i), res)
        # Stop after the entry with index 10001 has been processed.
        if i > 10000:
            break
# Run the conversion only when this file is executed as a script, so that
# importing the module does not read sys.argv or open any file.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment