@jcjohnson
Created April 13, 2019 22:31
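"""
Generate questions for scenes by instantiating question templates via a
depth-first search over template parameters, with rejection sampling to keep
the answer distribution roughly flat per template. Writes a JSON list pairing
each text question with its structured program and answer.

The `question_engine` module (imported as qeng below) must be importable from
the working directory; an illustrative invocation is sketched at the bottom
of the file.
"""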
import argparse, json, os, itertools, random, shutil
import time
import re
import question_engine as qeng
parser = argparse.ArgumentParser()
parser.add_argument('--scene_dir', default='data/rubber_metal_100/scenes')
parser.add_argument('--scene_file', default='')
parser.add_argument('--scene_start_idx', default=0, type=int)
parser.add_argument('--num_scenes', default=100, type=int)
parser.add_argument('--metadata_file', default='metadata.json')
parser.add_argument('--filter_options_file', default='data/VG/filter_options.json')
parser.add_argument('--template_dir', default='templates3/')
parser.add_argument('--instances_per_template', default=1, type=int)
parser.add_argument('--templates_per_image', default=10, type=int)
parser.add_argument('--reset_counts_every', default=250, type=int)
parser.add_argument('--output_json', default='questions.json')
parser.add_argument('--profile', default=0, type=int)
parser.add_argument('--use_local_copies', default=0, type=int)
parser.add_argument('--fbcode_imports', default=0, type=int)
args = parser.parse_args()
def expand_template_options(template):
  new_templates = []
  optional_idxs = [i for i, n in enumerate(template['nodes']) if n['optional']]
  for r in range(len(optional_idxs) + 1):
    for c in itertools.combinations(optional_idxs, r):
      nodes_to_exclude = set(c)
      new_template = {
        'nodes': [],
        'params': [],
        'text': [],
        'constraints': [],
      }

      # Add nodes to the new template, and also figure out which template
      # params should be excluded.
      params_to_exclude = set()
      node_idx_map = {}  # maps old idx to new idx
      for node_idx, node in enumerate(template['nodes']):
        if node_idx not in nodes_to_exclude:
          new_node = {
            'type': node['type'],
            'inputs': [node_idx_map[idx] for idx in node['inputs']],
          }
          if 'side_inputs' in node:
            new_node['side_inputs'] = node['side_inputs']
          node_idx_map[node_idx] = len(new_template['nodes'])
          new_template['nodes'].append(new_node)
        elif 'side_inputs' in node:
          for param_name in node['side_inputs']:
            params_to_exclude.add(param_name)
          node_idx_map[node_idx] = node_idx_map[node['inputs'][0]]

      # Add params to the new template
      for param in template['params']:
        if param['name'] not in params_to_exclude:
          new_template['params'].append(param)

      # Add text to the new template
      for txt in template['text']:
        for param in template['params']:
          if param['name'] in params_to_exclude:
            if param['type'] == 'Shape':
              txt = txt.replace(param['name'], 'thing')
            else:
              txt = txt.replace(param['name'], '')
        txt = ' '.join(txt.split())
        new_template['text'].append(txt)

      # Add constraints to the new template; skip any that contain params
      # that were excluded.
      for constraint in template['constraints']:
        should_include = True
        for param_name in constraint['params']:
          if param_name in params_to_exclude:
            should_include = False
            break
        if should_include:
          new_template['constraints'].append(constraint)

      new_templates.append(new_template)
  return new_templates

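# Illustrative toy template (hypothetical; real templates live in the JSON
# files under --template_dir). One optional node yields 2**1 = 2 expansions:
# one keeping the filter_color node and one dropping it.
#
#   toy = {
#     'nodes': [
#       {'type': 'scene', 'inputs': [], 'optional': False},
#       {'type': 'filter_color', 'inputs': [0], 'side_inputs': ['<C>'],
#        'optional': True},
#     ],
#     'params': [{'name': '<C>', 'type': 'Color'}],
#     'text': ['How many <C> things are there?'],
#     'constraints': [],
#   }
#   assert len(expand_template_options(toy)) == 2
#   # The expansion without the optional node reads 'How many things are there?'
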
def precompute_filter_options(scene_struct, metadata):
  # Keys are tuples (size, color, material, shape) (where some may be None)
  # and values are sets of object idxs that match the filter criterion
  attribute_map = {}

  if metadata['dataset'] == 'shapes':
    attr_keys = ['size', 'color', 'material', 'shape']
  elif metadata['dataset'] == 'visual-genome':
    attr_keys = ['color', 'material', 'objectcategory']
  else:
    assert False, 'Unrecognized dataset'

  # Precompute all 2 ** |attr_keys| binary masks over the attribute keys
  masks = []
  for i in range(2 ** len(attr_keys)):
    mask = []
    for j in range(len(attr_keys)):
      mask.append((i // (2 ** j)) % 2)
    masks.append(mask)

  for object_idx, obj in enumerate(scene_struct['objects']):
    if metadata['dataset'] == 'shapes':
      keys = [tuple(obj[k] for k in attr_keys)]
    elif metadata['dataset'] == 'visual-genome':
      keys = list(itertools.product(*[obj[k] + [None] for k in attr_keys]))
      keys = [tuple(k) for k in keys]
    for mask in masks:
      for key in keys:
        masked_key = []
        for a, b in zip(key, mask):
          if b == 1:
            masked_key.append(a)
          else:
            masked_key.append(None)
        masked_key = tuple(masked_key)
        if masked_key not in attribute_map:
          attribute_map[masked_key] = set()
        attribute_map[masked_key].add(object_idx)

  scene_struct['_filter_options'] = attribute_map

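# Example of the precomputed index, assuming a 'shapes'-style scene: with
# attr_keys (size, color, material, shape), the key ('large', None, 'metal',
# None) maps to the idxs of all large metal objects, and the all-None key
# (None, None, None, None) maps to every object in the scene.
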
def find_filter_options(object_idxs, scene_struct, metadata):
  # Keys are tuples (size, color, material, shape) (where some may be None)
  # and values are sorted lists of object idxs that match the filter criterion
  if '_filter_options' not in scene_struct:
    precompute_filter_options(scene_struct, metadata)
  attribute_map = {}
  object_idxs = set(object_idxs)
  for k, vs in scene_struct['_filter_options'].items():
    attribute_map[k] = sorted(list(object_idxs & vs))
  return attribute_map

def add_empty_filter_options(attribute_map, metadata, num_to_add):
  # Add some filtering criteria that do NOT correspond to objects
  if metadata['dataset'] == 'shapes':
    attr_keys = ['Size', 'Color', 'Material', 'Shape']
  elif metadata['dataset'] == 'visual-genome':
    attr_keys = ['Color', 'Material', 'ObjectCategory']
  else:
    assert False, 'Unrecognized dataset'

  attr_vals = [metadata['types'][t] + [None] for t in attr_keys]
  if '_filter_options' in metadata:
    attr_vals = metadata['_filter_options']

  target_size = len(attribute_map) + num_to_add
  while len(attribute_map) < target_size:
    # Build a tuple (not a generator) so the key is hashable and comparable
    k = tuple(random.choice(v) for v in attr_vals)
    if k not in attribute_map:
      attribute_map[k] = []

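# For example (hypothetical metadata), with num_to_add=1 this might insert a
# key like ('large', 'purple', 'rubber', 'sphere') with an empty object list,
# giving downstream filter_exist / filter_count questions whose answer is
# False / 0.
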
def find_relate_filter_options(object_idx, scene_struct, metadata,
                               unique=False, include_zero=False, trivial_frac=0.1):
  options = {}
  if '_filter_options' not in scene_struct:
    precompute_filter_options(scene_struct, metadata)
  if '_all_relationships' not in scene_struct:
    scene_struct['_all_relationships'] = qeng.compute_all_relationships(scene_struct)

  # TODO: Right now this is only looking for nontrivial combinations; in some
  # cases I may want to add trivial combinations, either where the intersection
  # is empty or where the intersection is equal to the filtering output.
  trivial_options = {}
  for relationship in scene_struct['_all_relationships']:
    related = set(scene_struct['_all_relationships'][relationship][object_idx])
    for filters, filtered in scene_struct['_filter_options'].items():
      intersection = related & filtered
      trivial = (intersection == filtered)
      if unique and len(intersection) != 1: continue
      if not include_zero and len(intersection) == 0: continue
      if trivial:
        trivial_options[(relationship, filters)] = sorted(list(intersection))
      else:
        options[(relationship, filters)] = sorted(list(intersection))

  N, f = len(options), trivial_frac
  num_trivial = round(N * f / (1 - f))
  trivial_options = list(trivial_options.items())
  random.shuffle(trivial_options)
  for k, v in trivial_options[:num_trivial]:
    options[k] = v

  return options

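# The returned dict keys pair a relationship with a filter key, e.g. in a
# 'shapes'-style scene:
#   options[('left', ('small', None, 'rubber', None))] == [2, 5]
# would mean objects 2 and 5 are the small rubber things left of object_idx.
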
def node_shallow_copy(node):
  new_node = {
    'type': node['type'],
    'inputs': node['inputs'],
  }
  if 'side_inputs' in node:
    new_node['side_inputs'] = node['side_inputs']
  return new_node

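# A shallow copy is enough here: question_engine caches per-node '_output'
# values on the node dicts while answering, and copying the dict keeps those
# cached outputs from leaking between template instantiations (see the note
# in instantiate_templates_dfs below).
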
def other_heuristic(text, param_vals):
  """
  Post-processing heuristic to handle the word "other"
  """
  if ' other ' not in text and ' another ' not in text:
    return text
  target_keys = {
    '<Z>', '<C>', '<M>', '<S>',
    '<Z2>', '<C2>', '<M2>', '<S2>',
  }
  if param_vals.keys() != target_keys:
    return text
  key_pairs = [
    ('<Z>', '<Z2>'),
    ('<C>', '<C2>'),
    ('<M>', '<M2>'),
    ('<S>', '<S2>'),
  ]
  remove_other = False
  for k1, k2 in key_pairs:
    v1 = param_vals.get(k1, None)
    v2 = param_vals.get(k2, None)
    if v1 != '' and v2 != '' and v1 != v2:
      print('other has got to go! %s = %s but %s = %s'
            % (k1, v1, k2, v2))
      remove_other = True
      break
  if remove_other:
    if ' other ' in text:
      text = text.replace(' other ', ' ')
    if ' another ' in text:
      text = text.replace(' another ', ' a ')
  return text

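# Illustration: if a question mentions two objects whose stated attributes
# already differ (say <S> = 'cube' but <S2> = 'sphere'), then "other" /
# "another" is unnecessary and gets dropped ("another" becomes "a"); if every
# stated attribute pair matches, the wording is kept.
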
def instantiate_templates_dfs(scene_struct, template, metadata, answer_counts,
                              max_instances=None, verbose=False):
  param_name_to_type = {p['name']: p['type'] for p in template['params']}

  initial_state = {
    'nodes': [node_shallow_copy(template['nodes'][0])],
    'vals': {},
    'input_map': {0: 0},
    'next_template_node': 1,
  }
  states = [initial_state]
  final_states = []
  while states:
    state = states.pop()

    # Check to make sure the current state is valid
    q = {'nodes': state['nodes']}
    outputs = qeng.answer_question(q, metadata, scene_struct, all_outputs=True)
    answer = outputs[-1]
    if answer == '__INVALID__': continue

    # Check to make sure constraints are satisfied for the current state
    skip_state = False
    for constraint in template['constraints']:
      if constraint['type'] == 'NEQ':
        p1, p2 = constraint['params']
        v1, v2 = state['vals'].get(p1), state['vals'].get(p2)
        if v1 is not None and v2 is not None and v1 != v2:
          if verbose:
            print('skipping due to NEQ constraint')
            print(constraint)
            print(state['vals'])
          skip_state = True
          break
      elif constraint['type'] == 'NULL':
        p = constraint['params'][0]
        p_type = param_name_to_type[p]
        v = state['vals'].get(p)
        if v is not None:
          skip = False
          if p_type == 'Shape' and v != 'thing': skip = True
          if p_type != 'Shape' and v != '': skip = True
          if skip:
            if verbose:
              print('skipping due to NULL constraint')
              print(constraint)
              print(state['vals'])
            skip_state = True
            break
      elif constraint['type'] == 'OUT_NEQ':
        i, j = constraint['params']
        i = state['input_map'].get(i, None)
        j = state['input_map'].get(j, None)
        if i is not None and j is not None and outputs[i] == outputs[j]:
          if verbose:
            print('skipping due to OUT_NEQ constraint')
            print(outputs[i])
            print(outputs[j])
          skip_state = True
          break
      else:
        assert False, 'Unrecognized constraint type "%s"' % constraint['type']
    if skip_state:
      continue

    # We have already checked to make sure the answer is valid, so if we have
    # processed all the nodes in the template then the current state is a valid
    # question, so add it if it passes our rejection sampling tests.
    if state['next_template_node'] == len(template['nodes']):
      # Use our rejection sampling heuristics to decide whether we should
      # keep this template instantiation
      cur_answer_count = answer_counts[answer]
      answer_counts_sorted = sorted(answer_counts.values())
      median_count = answer_counts_sorted[len(answer_counts_sorted) // 2]
      median_count = max(median_count, 5)
      if cur_answer_count > 1.1 * answer_counts_sorted[-2]:
        if verbose: print('skipping due to second count')
        continue
      if cur_answer_count > 5.0 * median_count:
        if verbose: print('skipping due to median')
        continue

      # If the template contains a raw relate node then we need to check for
      # degeneracy at the end
      has_relate = any(n['type'] == 'relate' for n in template['nodes'])
      if has_relate:
        degen = qeng.is_degenerate(q, metadata, scene_struct, answer=answer,
                                   verbose=verbose)
        if degen:
          continue
      # if qeng.is_degenerate(q, metadata, scene_struct, answer=answer,
      #                       verbose=verbose):
      #   if verbose:
      #     print('skipping due to degeneracy')
      #     print('scene')
      #     for o in scene_struct['objects']:
      #       print(o['size'], o['color'], o['material'], o['shape'])
      #     print()
      #     print('question')
      #     for i, n in enumerate(q['nodes']):
      #       name = n['type']
      #       if 'side_inputs' in n:
      #         name = '%s[%s]' % (name, n['side_inputs'][0])
      #       print(i, name, n['_output'])
      #     print()
      #   continue

      answer_counts[answer] += 1
      state['answer'] = answer
      final_states.append(state)
      if max_instances is not None and len(final_states) == max_instances:
        break
      continue

    # Otherwise fetch the next node from the template
    # Make a shallow copy so cached _outputs don't leak ... this is very nasty
    next_node = template['nodes'][state['next_template_node']]
    next_node = node_shallow_copy(next_node)

    special_nodes = {
      'filter_unique', 'filter_count', 'filter_exist', 'filter',
      'relate_filter', 'relate_filter_unique', 'relate_filter_count',
      'relate_filter_exist',
    }
    if next_node['type'] in special_nodes:
      if next_node['type'].startswith('relate_filter'):
        unique = (next_node['type'] == 'relate_filter_unique')
        include_zero = (next_node['type'] == 'relate_filter_count'
                        or next_node['type'] == 'relate_filter_exist')
        filter_options = find_relate_filter_options(answer, scene_struct, metadata,
                           unique=unique, include_zero=include_zero)
      else:
        filter_options = find_filter_options(answer, scene_struct, metadata)
        if next_node['type'] == 'filter':
          # Remove null filter
          filter_options.pop((None, None, None, None), None)
        if next_node['type'] == 'filter_unique':
          # Get rid of all filter options that don't result in a single object
          filter_options = {k: v for k, v in filter_options.items()
                            if len(v) == 1}
        else:
          # Add some filter options that do NOT correspond to the scene
          if next_node['type'] == 'filter_exist':
            # For filter_exist we want an equal number that do and don't
            num_to_add = len(filter_options)
          elif next_node['type'] == 'filter_count' or next_node['type'] == 'filter':
            # For filter_count add nulls equal to the number of singletons
            num_to_add = sum(1 for k, v in filter_options.items()
                             if len(v) == 1)
          add_empty_filter_options(filter_options, metadata, num_to_add)

      filter_option_keys = list(filter_options.keys())
      random.shuffle(filter_option_keys)
      for k in filter_option_keys:
        new_nodes = []
        cur_next_vals = {k: v for k, v in state['vals'].items()}
        next_input = state['input_map'][next_node['inputs'][0]]
        filter_side_inputs = next_node['side_inputs']
        if next_node['type'].startswith('relate'):
          param_name = next_node['side_inputs'][0]  # First one should be relate
          filter_side_inputs = next_node['side_inputs'][1:]
          param_type = param_name_to_type[param_name]
          assert param_type == 'Relation'
          param_val = k[0]
          k = k[1]
          new_nodes.append({
            'type': 'relate',
            'inputs': [next_input],
            'side_inputs': [param_val],
          })
          cur_next_vals[param_name] = param_val
          next_input = len(state['nodes']) + len(new_nodes) - 1
        for param_name, param_val in zip(filter_side_inputs, k):
          param_type = param_name_to_type[param_name]
          filter_type = 'filter_%s' % param_type.lower()
          if param_val is not None:
            new_nodes.append({
              'type': filter_type,
              'inputs': [next_input],
              'side_inputs': [param_val],
            })
            cur_next_vals[param_name] = param_val
            next_input = len(state['nodes']) + len(new_nodes) - 1
          elif param_val is None:
            if metadata['dataset'] == 'shapes' and param_type == 'Shape':
              param_val = 'thing'
            elif metadata['dataset'] == 'visual-genome' and param_type == 'ObjectCategory':
              param_val = 'thing'
            else:
              param_val = ''
            cur_next_vals[param_name] = param_val

        input_map = {k: v for k, v in state['input_map'].items()}
        extra_type = None
        if next_node['type'].endswith('unique'):
          extra_type = 'unique'
        if next_node['type'].endswith('count'):
          extra_type = 'count'
        if next_node['type'].endswith('exist'):
          extra_type = 'exist'
        if extra_type is not None:
          # if next_node['type'] != 'filter':
          new_nodes.append({
            # 'type': next_node['type'].split('_')[1],
            'type': extra_type,
            'inputs': [input_map[next_node['inputs'][0]] + len(new_nodes)],
          })
        input_map[state['next_template_node']] = len(state['nodes']) + len(new_nodes) - 1
        states.append({
          'nodes': state['nodes'] + new_nodes,
          'vals': cur_next_vals,
          'input_map': input_map,
          'next_template_node': state['next_template_node'] + 1,
        })
    elif 'side_inputs' in next_node:
      # If the next node has template parameters, expand them out
      # TODO: Generalize this to work for nodes with more than one side input
      assert len(next_node['side_inputs']) == 1, 'NOT IMPLEMENTED'

      # Use metadata to figure out the domain of valid values for this parameter.
      # Iterate over the values in a random order; then it is safe to bail
      # from the DFS as soon as we find the desired number of valid template
      # instantiations.
      param_name = next_node['side_inputs'][0]
      param_type = param_name_to_type[param_name]
      param_vals = metadata['types'][param_type][:]
      random.shuffle(param_vals)
      for val in param_vals:
        input_map = {k: v for k, v in state['input_map'].items()}
        input_map[state['next_template_node']] = len(state['nodes'])
        cur_next_node = {
          'type': next_node['type'],
          'inputs': [input_map[idx] for idx in next_node['inputs']],
          'side_inputs': [val],
        }
        cur_next_vals = {k: v for k, v in state['vals'].items()}
        cur_next_vals[param_name] = val
        states.append({
          'nodes': state['nodes'] + [cur_next_node],
          'vals': cur_next_vals,
          'input_map': input_map,
          'next_template_node': state['next_template_node'] + 1,
        })
    else:
      input_map = {k: v for k, v in state['input_map'].items()}
      input_map[state['next_template_node']] = len(state['nodes'])
      next_node = {
        'type': next_node['type'],
        'inputs': [input_map[idx] for idx in next_node['inputs']],
      }
      states.append({
        'nodes': state['nodes'] + [next_node],
        'vals': state['vals'],
        'input_map': input_map,
        'next_template_node': state['next_template_node'] + 1,
      })

  if metadata['dataset'] == 'shapes':
    synonyms = {
      'thing': ['thing', 'object'],
      'sphere': ['sphere', 'ball'],
      'cube': ['cube', 'block'],
      'large': ['large', 'big'],
      'small': ['small', 'tiny'],
      'metal': ['metallic', 'metal', 'shiny'],
      'rubber': ['rubber', 'matte'],
      'left': ['left of', 'to the left of', 'on the left side of'],
      'right': ['right of', 'to the right of', 'on the right side of'],
      'behind': ['behind'],
      'front': ['in front of'],
      'above': ['above'],
      'below': ['below'],
    }
  elif metadata['dataset'] == 'visual-genome':
    # Wrap each relation in a list so random.choice below picks the whole
    # phrase rather than a single character of it
    synonyms = {r: [r] for r in metadata['types']['Relation']}

  text_questions, structured_questions, answers = [], [], []
  for state in final_states:
    structured_questions.append(state['nodes'])
    answers.append(state['answer'])
    text = random.choice(template['text'])
    for name, val in state['vals'].items():
      if val in synonyms: val = random.choice(synonyms[val])
      text = text.replace(name, val)
      text = ' '.join(text.split())
    text = replace_optionals(text)
    text = ' '.join(text.split())
    text = other_heuristic(text, state['vals'])
    text_questions.append(text)

  return text_questions, structured_questions, answers

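# Sketch of a single call (names here are illustrative; answer_counts must map
# every possible answer for this template's output type to a running count):
#
#   counts = {a: 0 for a in possible_answers}
#   texts, programs, answers = instantiate_templates_dfs(
#       scene_struct, template, metadata, counts, max_instances=1)
#
# Each text question lines up index-for-index with a structured program (a
# list of function nodes) and its answer.
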
def replace_optionals(s):
  """
  Each substring of s that is surrounded in square brackets is treated as
  optional and is removed with probability 0.5. For example the string

  "A [aa] B [bb]"

  could become any of

  "A aa B bb"
  "A B bb"
  "A aa B"
  "A B"

  each with probability 1/4.
  """
  pat = re.compile(r'\[([^\[]*)\]')

  while True:
    match = re.search(pat, s)
    if not match:
      break
    i0 = match.start()
    i1 = match.end()
    if random.random() > 0.5:
      s = s[:i0] + match.groups()[0] + s[i1:]
    else:
      s = s[:i0] + s[i1:]
  return s

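# Note the pattern forbids '[' inside a match, so it matches the leftmost
# innermost bracket pair first; nested brackets like "[a [b] c]" are resolved
# innermost-first across iterations of the loop.
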
def main(args):
  with open(args.metadata_file, 'r') as f:
    metadata = json.load(f)
  if metadata['dataset'] == 'visual-genome':
    with open(args.filter_options_file, 'r') as f:
      metadata['_filter_options'] = json.load(f)
  functions_by_name = {}
  for f in metadata['functions']:
    functions_by_name[f['name']] = f
  metadata['_functions_by_name'] = functions_by_name

  # Load templates; key is (filename, file_idx, expansion_idx)
  num_loaded_templates = 0
  num_expanded_templates = 0
  templates = {}
  for fn in os.listdir(args.template_dir):
    if not fn.endswith('.json'): continue
    with open(os.path.join(args.template_dir, fn), 'r') as f:
      base = os.path.splitext(fn)[0]
      for i, template in enumerate(json.load(f)):
        num_loaded_templates += 1
        # expanded = expand_template_options(template)
        expanded = [template]
        num_expanded_templates += len(expanded)
        for j, t in enumerate(expanded):
          key = (fn, i, j)
          templates[key] = t
  print('Read %d templates from disk' % num_loaded_templates)
  print('After expansion there are %d templates' % num_expanded_templates)

  def reset_counts():
    # Maps a template (filename, index) to the number of questions we have
    # so far using that template
    template_counts = {}
    # Maps a template (filename, index) to a dict mapping the answer to the
    # number of questions so far of that template type with that answer
    template_answer_counts = {}
    node_type_to_dtype = {n['name']: n['output'] for n in metadata['functions']}
    for key, template in templates.items():
      template_counts[key[:2]] = 0
      final_node_type = template['nodes'][-1]['type']
      final_dtype = node_type_to_dtype[final_node_type]
      answers = metadata['types'][final_dtype]
      if final_dtype == 'Bool':
        answers = [True, False]
      if final_dtype == 'Integer':
        if metadata['dataset'] == 'shapes':
          answers = list(range(0, 11))
        elif metadata['dataset'] == 'visual-genome':
          answers = list(range(0, 150))
      template_answer_counts[key[:2]] = {}
      for a in answers:
        template_answer_counts[key[:2]][a] = 0
    return template_counts, template_answer_counts

  template_counts, template_answer_counts = reset_counts()

  # Exactly one of --scene_dir and --scene_file must be given
  assert (args.scene_dir == '') != (args.scene_file == '')
  all_scenes = []
  if args.scene_dir != '':
    for fn in os.listdir(args.scene_dir):
      if not fn.endswith('.json'):
        continue
      with open(os.path.join(args.scene_dir, fn), 'r') as f:
        all_scenes.append({
          'filename': fn,
          'scene': json.load(f),
        })
    if args.num_scenes > 0:
      all_scenes = all_scenes[:args.num_scenes]
  elif args.scene_file != '':
    if args.use_local_copies == 0:
      scene_file = args.scene_file
    else:
      print('copying scene file to local directory...')
      shutil.copy(args.scene_file, '_scene_file.json')
      scene_file = '_scene_file.json'
      print('done')
    print('reading scene data from ', scene_file)
    with open(scene_file, 'r') as f:
      all_scenes = json.load(f)
    print('done')
    begin = args.scene_start_idx
    if args.num_scenes > 0:
      end = args.scene_start_idx + args.num_scenes
      all_scenes = all_scenes[begin:end]
    else:
      all_scenes = all_scenes[begin:]

  questions = []
  scene_count = 0
  for i, scene in enumerate(all_scenes):
    scene_fn = scene['filename']
    scene_struct = scene['scene']
    print('starting image %s (%d / %d)'
          % (scene_fn, i + 1, len(all_scenes)))
    print('number of objects: ', len(scene_struct['objects']))
    if scene_count % args.reset_counts_every == 0:
      print('resetting counts')
      template_counts, template_answer_counts = reset_counts()
    scene_count += 1

    # Order templates by the number of questions we have so far for those
    # templates. This is a simple heuristic to give a flat distribution over
    # templates.
    templates_items = list(templates.items())
    templates_items = sorted(templates_items,
                             key=lambda x: template_counts[x[0][:2]])
    num_instantiated = 0
    for (fn, idx, jdx), template in templates_items:
      # print('starting template ', fn, idx)
      print('trying template ', fn, idx)
      # verbose = fn == 'double_relate.json' and idx == 1
      verbose = False
      tic = time.time()
      ts, qs, ans = instantiate_templates_dfs(
        scene_struct,
        template,
        metadata,
        template_answer_counts[(fn, idx)],
        max_instances=args.instances_per_template,
        verbose=verbose)
      toc = time.time()
      print('that took ', toc - tic)
      for t, q, a in zip(ts, qs, ans):
        questions.append({
          'image': os.path.splitext(scene_fn)[0],
          'text_question': t,
          'structured_question': q,
          'answer': a,
          'template_filename': fn,
          'template_idx': idx,
          'template_expand_idx': jdx,
          'question_id': len(questions),
        })
      if len(ts) > 0:
        print('got one!')
        num_instantiated += 1
        template_counts[(fn, idx)] += 1
      else:
        print('did not get any =(')
      if num_instantiated >= args.templates_per_image:
        break

  with open(args.output_json, 'w') as f:
    json.dump(questions, f)

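# Each record written to --output_json carries the text question, its
# structured program ('structured_question'), the answer, and enough template
# metadata (template_filename / template_idx / template_expand_idx) to trace
# the question back to the template that produced it.
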
if __name__ == '__main__':
  if args.profile == 1:
    import cProfile
    cProfile.run('main(args)')
  else:
    main(args)

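# A minimal sketch of an invocation (the script filename is hypothetical;
# flags and defaults come from the argparse setup at the top of the file):
#
#   python generate_questions.py \
#     --scene_dir data/rubber_metal_100/scenes \
#     --metadata_file metadata.json \
#     --template_dir templates3/ \
#     --output_json questions.json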