import argparse, json, os, itertools, random, shutil
import time
import re

import question_engine as qeng

parser = argparse.ArgumentParser()
parser.add_argument('--scene_dir', default='data/rubber_metal_100/scenes')
parser.add_argument('--scene_file', default='')
parser.add_argument('--scene_start_idx', default=0, type=int)
parser.add_argument('--num_scenes', default=100, type=int)
parser.add_argument('--metadata_file', default='metadata.json')
parser.add_argument('--filter_options_file', default='data/VG/filter_options.json')
parser.add_argument('--template_dir', default='templates3/')
parser.add_argument('--instances_per_template', default=1, type=int)
parser.add_argument('--templates_per_image', default=10, type=int)
parser.add_argument('--reset_counts_every', default=250, type=int)
parser.add_argument('--output_json', default='questions.json')
parser.add_argument('--profile', default=0, type=int)
parser.add_argument('--use_local_copies', default=0, type=int)
parser.add_argument('--fbcode_imports', default=0, type=int)

args = parser.parse_args()
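
# Example invocation (illustrative; the script name and paths are assumptions,
# adjust them to your setup):
#   python generate_questions.py \
#     --scene_dir data/rubber_metal_100/scenes \
#     --metadata_file metadata.json \
#     --template_dir templates3/ \
#     --output_json questions.json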

def expand_template_options(template):
  new_templates = []
  optional_idxs = [i for i, n in enumerate(template['nodes']) if n['optional']]
  for r in range(len(optional_idxs) + 1):
    for c in itertools.combinations(optional_idxs, r):
      nodes_to_exclude = set(c)
      new_template = {
        'nodes': [],
        'params': [],
        'text': [],
        'constraints': [],
      }
      # Add nodes to the new template, and also figure out which template
      # params should be excluded.
      params_to_exclude = set()
      node_idx_map = {}  # maps old idx to new idx
      for node_idx, node in enumerate(template['nodes']):
        if node_idx not in nodes_to_exclude:
          new_node = {
            'type': node['type'],
            'inputs': [node_idx_map[idx] for idx in node['inputs']],
          }
          if 'side_inputs' in node:
            new_node['side_inputs'] = node['side_inputs']
          node_idx_map[node_idx] = len(new_template['nodes'])
          new_template['nodes'].append(new_node)
        elif 'side_inputs' in node:
          for param_name in node['side_inputs']:
            params_to_exclude.add(param_name)
          node_idx_map[node_idx] = node_idx_map[node['inputs'][0]]
      # Add params to the new template
      for param in template['params']:
        if param['name'] not in params_to_exclude:
          new_template['params'].append(param)
      # Add text to the new template
      for txt in template['text']:
        for param in template['params']:
          if param['name'] in params_to_exclude:
            if param['type'] == 'Shape':
              txt = txt.replace(param['name'], 'thing')
            else:
              txt = txt.replace(param['name'], '')
        txt = ' '.join(txt.split())
        new_template['text'].append(txt)
      # Add constraints to the new template; skip any that contain params
      # that were excluded.
      for constraint in template['constraints']:
        should_include = True
        for param_name in constraint['params']:
          if param_name in params_to_exclude:
            should_include = False
            break
        if should_include:
          new_template['constraints'].append(constraint)
      new_templates.append(new_template)
  return new_templates
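
# Note on expand_template_options: every subset of a template's optional nodes
# is dropped in turn, so a template with k optional nodes expands into 2**k
# variants. When an excluded node's parameter has type Shape, its placeholder
# in the question text falls back to the word "thing"; other excluded
# parameters are simply deleted from the text.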

def precompute_filter_options(scene_struct, metadata):
  # Keys are tuples (size, color, shape, material) (where some may be None)
  # and values are sets of object idxs that match the filter criterion
  attribute_map = {}

  if metadata['dataset'] == 'shapes':
    attr_keys = ['size', 'color', 'material', 'shape']
  elif metadata['dataset'] == 'visual-genome':
    attr_keys = ['color', 'material', 'objectcategory']
  else:
    assert False, 'Unrecognized dataset'

  # Precompute masks
  masks = []
  for i in range(2 ** len(attr_keys)):
    mask = []
    for j in range(len(attr_keys)):
      mask.append((i // (2 ** j)) % 2)
    masks.append(mask)

  for object_idx, obj in enumerate(scene_struct['objects']):
    if metadata['dataset'] == 'shapes':
      keys = [tuple(obj[k] for k in attr_keys)]
    elif metadata['dataset'] == 'visual-genome':
      keys = list(itertools.product(*[obj[k] + [None] for k in attr_keys]))
      keys = [tuple(k) for k in keys]

    for mask in masks:
      for key in keys:
        masked_key = []
        for a, b in zip(key, mask):
          if b == 1:
            masked_key.append(a)
          else:
            masked_key.append(None)
        masked_key = tuple(masked_key)
        if masked_key not in attribute_map:
          attribute_map[masked_key] = set()
        attribute_map[masked_key].add(object_idx)

  scene_struct['_filter_options'] = attribute_map
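
# Worked example of the masking above: with four attr_keys there are 2**4 = 16
# masks. The mask [1, 0, 0, 1] keeps the first and last attributes and blanks
# the rest with None, so a large red metal cube contributes the key
# ('large', None, None, 'cube'). (Attribute values here are illustrative.)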

def find_filter_options(object_idxs, scene_struct, metadata):
  # Keys are tuples (size, color, shape, material) (where some may be None)
  # and values are lists of object idxs that match the filter criterion
  if '_filter_options' not in scene_struct:
    precompute_filter_options(scene_struct, metadata)
  attribute_map = {}
  object_idxs = set(object_idxs)
  for k, vs in scene_struct['_filter_options'].items():
    attribute_map[k] = sorted(list(object_idxs & vs))
  return attribute_map

def add_empty_filter_options(attribute_map, metadata, num_to_add):
  # Add some filtering criteria that do NOT correspond to objects
  if metadata['dataset'] == 'shapes':
    attr_keys = ['Size', 'Color', 'Material', 'Shape']
  elif metadata['dataset'] == 'visual-genome':
    attr_keys = ['Color', 'Material', 'ObjectCategory']
  else:
    assert False, 'Unrecognized dataset'

  attr_vals = [metadata['types'][t] + [None] for t in attr_keys]
  if '_filter_options' in metadata:
    attr_vals = metadata['_filter_options']

  target_size = len(attribute_map) + num_to_add
  while len(attribute_map) < target_size:
    # Build a hashable tuple key; the bare generator expression that was here
    # before would never match an existing key and is not a usable dict key.
    k = tuple(random.choice(v) for v in attr_vals)
    if k not in attribute_map:
      attribute_map[k] = []
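
# Example: with num_to_add = 2 this might add keys like
# ('small', 'blue', None, 'sphere') that match no object in the scene, mapped
# to an empty list, so downstream filter_exist / filter_count questions can
# also have answers of False / 0. (The specific tuple is illustrative.)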

def find_relate_filter_options(object_idx, scene_struct, metadata,
                               unique=False, include_zero=False,
                               trivial_frac=0.1):
  options = {}
  if '_filter_options' not in scene_struct:
    precompute_filter_options(scene_struct, metadata)
  if '_all_relationships' not in scene_struct:
    scene_struct['_all_relationships'] = qeng.compute_all_relationships(scene_struct)

  # TODO: Right now this is only looking for nontrivial combinations; in some
  # cases I may want to add trivial combinations, either where the intersection
  # is empty or where the intersection is equal to the filtering output.
  trivial_options = {}
  for relationship in scene_struct['_all_relationships']:
    related = set(scene_struct['_all_relationships'][relationship][object_idx])
    for filters, filtered in scene_struct['_filter_options'].items():
      intersection = related & filtered
      trivial = (intersection == filtered)
      if unique and len(intersection) != 1: continue
      if not include_zero and len(intersection) == 0: continue
      if trivial:
        trivial_options[(relationship, filters)] = sorted(list(intersection))
      else:
        options[(relationship, filters)] = sorted(list(intersection))

  N, f = len(options), trivial_frac
  num_trivial = round(N * f / (1 - f))
  trivial_options = list(trivial_options.items())
  random.shuffle(trivial_options)
  for k, v in trivial_options[:num_trivial]:
    options[k] = v

  return options
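
# The num_trivial arithmetic keeps trivial combinations at roughly trivial_frac
# of the final option set: adding t = round(N * f / (1 - f)) trivial entries to
# N nontrivial ones gives t / (N + t) ~= f. E.g. N = 18 and f = 0.1 yield
# t = 2, so 2 of the 20 returned options (10%) are trivial.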

def node_shallow_copy(node):
  new_node = {
    'type': node['type'],
    'inputs': node['inputs'],
  }
  if 'side_inputs' in node:
    new_node['side_inputs'] = node['side_inputs']
  return new_node

def other_heuristic(text, param_vals):
  """
  Post-processing heuristic to handle the word "other"
  """
  if ' other ' not in text and ' another ' not in text:
    return text
  target_keys = {
    '<Z>', '<C>', '<M>', '<S>',
    '<Z2>', '<C2>', '<M2>', '<S2>',
  }
  if param_vals.keys() != target_keys:
    return text
  key_pairs = [
    ('<Z>', '<Z2>'),
    ('<C>', '<C2>'),
    ('<M>', '<M2>'),
    ('<S>', '<S2>'),
  ]
  remove_other = False
  for k1, k2 in key_pairs:
    v1 = param_vals.get(k1, None)
    v2 = param_vals.get(k2, None)
    if v1 != '' and v2 != '' and v1 != v2:
      print('other has got to go! %s = %s but %s = %s'
            % (k1, v1, k2, v2))
      remove_other = True
      break
  if remove_other:
    if ' other ' in text:
      text = text.replace(' other ', ' ')
    if ' another ' in text:
      text = text.replace(' another ', ' a ')
  return text
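
# Example of the heuristic above: if the question mentions two objects whose
# filled-in attributes already differ (say <C> = 'red' but <C2> = 'blue'), the
# word "other" is redundant, so " other " is dropped and " another " becomes
# " a ". (Attribute values here are illustrative.)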

def instantiate_templates_dfs(scene_struct, template, metadata, answer_counts,
                              max_instances=None, verbose=False):
  param_name_to_type = {p['name']: p['type'] for p in template['params']}

  initial_state = {
    'nodes': [node_shallow_copy(template['nodes'][0])],
    'vals': {},
    'input_map': {0: 0},
    'next_template_node': 1,
  }
  states = [initial_state]
  final_states = []
  while states:
    state = states.pop()

    # Check to make sure the current state is valid
    q = {'nodes': state['nodes']}
    outputs = qeng.answer_question(q, metadata, scene_struct, all_outputs=True)
    answer = outputs[-1]
    if answer == '__INVALID__': continue
    # Check to make sure constraints are satisfied for the current state
    skip_state = False
    for constraint in template['constraints']:
      if constraint['type'] == 'NEQ':
        p1, p2 = constraint['params']
        v1, v2 = state['vals'].get(p1), state['vals'].get(p2)
        if v1 is not None and v2 is not None and v1 != v2:
          if verbose:
            print('skipping due to NEQ constraint')
            print(constraint)
            print(state['vals'])
          skip_state = True
          break
      elif constraint['type'] == 'NULL':
        p = constraint['params'][0]
        p_type = param_name_to_type[p]
        v = state['vals'].get(p)
        if v is not None:
          skip = False
          if p_type == 'Shape' and v != 'thing': skip = True
          if p_type != 'Shape' and v != '': skip = True
          if skip:
            if verbose:
              print('skipping due to NULL constraint')
              print(constraint)
              print(state['vals'])
            skip_state = True
            break
      elif constraint['type'] == 'OUT_NEQ':
        i, j = constraint['params']
        i = state['input_map'].get(i, None)
        j = state['input_map'].get(j, None)
        if i is not None and j is not None and outputs[i] == outputs[j]:
          if verbose:
            print('skipping due to OUT_NEQ constraint')
            print(outputs[i])
            print(outputs[j])
          skip_state = True
          break
      else:
        assert False, 'Unrecognized constraint type "%s"' % constraint['type']

    if skip_state:
      continue
    # We have already checked to make sure the answer is valid, so if we have
    # processed all the nodes in the template then the current state is a valid
    # question, so add it if it passes our rejection sampling tests.
    if state['next_template_node'] == len(template['nodes']):
      # Use our rejection sampling heuristics to decide whether we should
      # keep this template instantiation
      cur_answer_count = answer_counts[answer]
      answer_counts_sorted = sorted(answer_counts.values())
      median_count = answer_counts_sorted[len(answer_counts_sorted) // 2]
      median_count = max(median_count, 5)
      if cur_answer_count > 1.1 * answer_counts_sorted[-2]:
        if verbose: print('skipping due to second count')
        continue
      if cur_answer_count > 5.0 * median_count:
        if verbose: print('skipping due to median')
        continue

      # If the template contains a raw relate node then we need to check for
      # degeneracy at the end
      has_relate = any(n['type'] == 'relate' for n in template['nodes'])
      if has_relate:
        degen = qeng.is_degenerate(q, metadata, scene_struct, answer=answer,
                                   verbose=verbose)
        if degen:
          continue

      # if qeng.is_degenerate(q, metadata, scene_struct, answer=answer,
      #                       verbose=verbose):
      #   if verbose:
      #     print('skipping due to degeneracy')
      #     print('scene')
      #     for o in scene_struct['objects']:
      #       print(o['size'], o['color'], o['material'], o['shape'])
      #     print()
      #     print('question')
      #     for i, n in enumerate(q['nodes']):
      #       name = n['type']
      #       if 'side_inputs' in n:
      #         name = '%s[%s]' % (name, n['side_inputs'][0])
      #       print(i, name, n['_output'])
      #     print()
      #   continue

      answer_counts[answer] += 1
      state['answer'] = answer
      final_states.append(state)
      if max_instances is not None and len(final_states) == max_instances:
        break
      continue
    # Otherwise fetch the next node from the template
    # Make a shallow copy so cached _outputs don't leak ... this is very nasty
    next_node = template['nodes'][state['next_template_node']]
    next_node = node_shallow_copy(next_node)

    special_nodes = {
      'filter_unique', 'filter_count', 'filter_exist', 'filter',
      'relate_filter', 'relate_filter_unique', 'relate_filter_count',
      'relate_filter_exist',
    }
    if next_node['type'] in special_nodes:
      if next_node['type'].startswith('relate_filter'):
        unique = (next_node['type'] == 'relate_filter_unique')
        include_zero = (next_node['type'] == 'relate_filter_count'
                        or next_node['type'] == 'relate_filter_exist')
        filter_options = find_relate_filter_options(
            answer, scene_struct, metadata,
            unique=unique, include_zero=include_zero)
      else:
        filter_options = find_filter_options(answer, scene_struct, metadata)
        if next_node['type'] == 'filter':
          # Remove null filter
          filter_options.pop((None, None, None, None), None)
        if next_node['type'] == 'filter_unique':
          # Get rid of all filter options that don't result in a single object
          filter_options = {k: v for k, v in filter_options.items()
                            if len(v) == 1}
        else:
          # Add some filter options that do NOT correspond to the scene
          if next_node['type'] == 'filter_exist':
            # For filter_exist we want an equal number that do and don't
            num_to_add = len(filter_options)
          elif next_node['type'] == 'filter_count' or next_node['type'] == 'filter':
            # For filter_count add nulls equal to the number of singletons
            num_to_add = sum(1 for k, v in filter_options.items() if len(v) == 1)
          add_empty_filter_options(filter_options, metadata, num_to_add)

      filter_option_keys = list(filter_options.keys())
      random.shuffle(filter_option_keys)
      for k in filter_option_keys:
        new_nodes = []
        # Note: the dict comprehensions below reuse k as their loop variable;
        # in Python 3 comprehension variables are scoped to the comprehension,
        # so the filter-option key k is not clobbered.
        cur_next_vals = {k: v for k, v in state['vals'].items()}
        next_input = state['input_map'][next_node['inputs'][0]]
        filter_side_inputs = next_node['side_inputs']
        if next_node['type'].startswith('relate'):
          param_name = next_node['side_inputs'][0]  # First one should be relate
          filter_side_inputs = next_node['side_inputs'][1:]
          param_type = param_name_to_type[param_name]
          assert param_type == 'Relation'
          param_val = k[0]
          k = k[1]
          new_nodes.append({
            'type': 'relate',
            'inputs': [next_input],
            'side_inputs': [param_val],
          })
          cur_next_vals[param_name] = param_val
          next_input = len(state['nodes']) + len(new_nodes) - 1
        for param_name, param_val in zip(filter_side_inputs, k):
          param_type = param_name_to_type[param_name]
          filter_type = 'filter_%s' % param_type.lower()
          if param_val is not None:
            new_nodes.append({
              'type': filter_type,
              'inputs': [next_input],
              'side_inputs': [param_val],
            })
            cur_next_vals[param_name] = param_val
            next_input = len(state['nodes']) + len(new_nodes) - 1
          elif param_val is None:
            if metadata['dataset'] == 'shapes' and param_type == 'Shape':
              param_val = 'thing'
            elif metadata['dataset'] == 'visual-genome' and param_type == 'ObjectCategory':
              param_val = 'thing'
            else:
              param_val = ''
            cur_next_vals[param_name] = param_val

        input_map = {k: v for k, v in state['input_map'].items()}
        extra_type = None
        if next_node['type'].endswith('unique'):
          extra_type = 'unique'
        if next_node['type'].endswith('count'):
          extra_type = 'count'
        if next_node['type'].endswith('exist'):
          extra_type = 'exist'
        if extra_type is not None:
          new_nodes.append({
            'type': extra_type,
            'inputs': [input_map[next_node['inputs'][0]] + len(new_nodes)],
          })
        input_map[state['next_template_node']] = len(state['nodes']) + len(new_nodes) - 1
        states.append({
          'nodes': state['nodes'] + new_nodes,
          'vals': cur_next_vals,
          'input_map': input_map,
          'next_template_node': state['next_template_node'] + 1,
        })
    elif 'side_inputs' in next_node:
      # If the next node has template parameters, expand them out
      # TODO: Generalize this to work for nodes with more than one side input
      assert len(next_node['side_inputs']) == 1, 'NOT IMPLEMENTED'

      # Use metadata to figure out domain of valid values for this parameter.
      # Iterate over the values in a random order; then it is safe to bail
      # from the DFS as soon as we find the desired number of valid template
      # instantiations.
      param_name = next_node['side_inputs'][0]
      param_type = param_name_to_type[param_name]
      param_vals = metadata['types'][param_type][:]
      random.shuffle(param_vals)
      for val in param_vals:
        input_map = {k: v for k, v in state['input_map'].items()}
        input_map[state['next_template_node']] = len(state['nodes'])
        cur_next_node = {
          'type': next_node['type'],
          'inputs': [input_map[idx] for idx in next_node['inputs']],
          'side_inputs': [val],
        }
        cur_next_vals = {k: v for k, v in state['vals'].items()}
        cur_next_vals[param_name] = val
        states.append({
          'nodes': state['nodes'] + [cur_next_node],
          'vals': cur_next_vals,
          'input_map': input_map,
          'next_template_node': state['next_template_node'] + 1,
        })
    else:
      input_map = {k: v for k, v in state['input_map'].items()}
      input_map[state['next_template_node']] = len(state['nodes'])
      next_node = {
        'type': next_node['type'],
        'inputs': [input_map[idx] for idx in next_node['inputs']],
      }
      states.append({
        'nodes': state['nodes'] + [next_node],
        'vals': state['vals'],
        'input_map': input_map,
        'next_template_node': state['next_template_node'] + 1,
      })
  if metadata['dataset'] == 'shapes':
    synonyms = {
      'thing': ['thing', 'object'],
      'sphere': ['sphere', 'ball'],
      'cube': ['cube', 'block'],
      'large': ['large', 'big'],
      'small': ['small', 'tiny'],
      'metal': ['metallic', 'metal', 'shiny'],
      'rubber': ['rubber', 'matte'],
      'left': ['left of', 'to the left of', 'on the left side of'],
      'right': ['right of', 'to the right of', 'on the right side of'],
      'behind': ['behind'],
      'front': ['in front of'],
      'above': ['above'],
      'below': ['below'],
    }
  elif metadata['dataset'] == 'visual-genome':
    # Wrap each relation in a singleton list so random.choice below picks the
    # whole string rather than a single character of it.
    synonyms = {r: [r] for r in metadata['types']['Relation']}

  text_questions, structured_questions, answers = [], [], []
  for state in final_states:
    structured_questions.append(state['nodes'])
    answers.append(state['answer'])
    text = random.choice(template['text'])
    for name, val in state['vals'].items():
      if val in synonyms: val = random.choice(synonyms[val])
      text = text.replace(name, val)
    text = ' '.join(text.split())
    text = replace_optionals(text)
    text = ' '.join(text.split())
    text = other_heuristic(text, state['vals'])
    text_questions.append(text)

  return text_questions, structured_questions, answers
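
# Sketch of the DFS above: each state holds a partially-built question program.
# Popping a state either (a) completes it and subjects it to rejection sampling
# against answer_counts so that no single answer dominates, or (b) expands the
# next template node into all valid parameter bindings and pushes the resulting
# child states. Finished states are then rendered into question text using the
# synonym table and the optional-text machinery below.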

def replace_optionals(s):
  """
  Each substring of s that is surrounded in square brackets is treated as
  optional and is removed with probability 0.5. For example the string

  "A [aa] B [bb]"

  could become any of

  "A aa B bb"
  "A B bb"
  "A aa B "
  "A B "

  each with probability 1/4.
  """
  pat = re.compile(r'\[([^\[]*)\]')

  while True:
    match = re.search(pat, s)
    if not match:
      break
    i0 = match.start()
    i1 = match.end()
    if random.random() > 0.5:
      s = s[:i0] + match.groups()[0] + s[i1:]
    else:
      s = s[:i0] + s[i1:]
  return s
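
# Usage example (illustrative): replace_optionals('How many [small] cubes are
# there?') returns 'How many small cubes are there?' or 'How many  cubes are
# there?' with equal probability; callers collapse the doubled spaces with
# ' '.join(text.split()).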

def main(args):
  with open(args.metadata_file, 'r') as f:
    metadata = json.load(f)
  if metadata['dataset'] == 'visual-genome':
    with open(args.filter_options_file, 'r') as f:
      metadata['_filter_options'] = json.load(f)
  functions_by_name = {}
  for f in metadata['functions']:
    functions_by_name[f['name']] = f
  metadata['_functions_by_name'] = functions_by_name

  # Key is (filename, file_idx, expansion_idx)
  num_loaded_templates = 0
  num_expanded_templates = 0
  templates = {}
  for fn in os.listdir(args.template_dir):
    if not fn.endswith('.json'): continue
    with open(os.path.join(args.template_dir, fn), 'r') as f:
      base = os.path.splitext(fn)[0]
      for i, template in enumerate(json.load(f)):
        num_loaded_templates += 1
        # expanded = expand_template_options(template)
        expanded = [template]
        num_expanded_templates += len(expanded)
        for j, t in enumerate(expanded):
          key = (fn, i, j)
          templates[key] = t
  print('Read %d templates from disk' % num_loaded_templates)
  print('After expansion there are %d templates' % num_expanded_templates)

  def reset_counts():
    # Maps a template (filename, index) to the number of questions we have
    # so far using that template
    template_counts = {}
    # Maps a template (filename, index) to a dict mapping the answer to the
    # number of questions so far of that template type with that answer
    template_answer_counts = {}
    node_type_to_dtype = {n['name']: n['output'] for n in metadata['functions']}
    for key, template in templates.items():
      template_counts[key[:2]] = 0
      final_node_type = template['nodes'][-1]['type']
      final_dtype = node_type_to_dtype[final_node_type]
      answers = metadata['types'][final_dtype]
      if final_dtype == 'Bool':
        answers = [True, False]
      if final_dtype == 'Integer':
        if metadata['dataset'] == 'shapes':
          answers = list(range(0, 11))
        elif metadata['dataset'] == 'visual-genome':
          answers = list(range(0, 150))
      template_answer_counts[key[:2]] = {}
      for a in answers:
        template_answer_counts[key[:2]][a] = 0
    return template_counts, template_answer_counts
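
  # Note: template_answer_counts, keyed by (filename, index), is what drives
  # the rejection sampling in instantiate_templates_dfs; answers for Bool and
  # Integer outputs are enumerated up front so every possible answer starts at
  # count 0.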
  template_counts, template_answer_counts = reset_counts()

  # Load scenes from either a directory of per-scene JSON files or a single
  # JSON file of all scenes; exactly one of the two options must be given.
  assert (args.scene_dir == '') != (args.scene_file == '')
  all_scenes = []
  if args.scene_dir != '':
    for fn in os.listdir(args.scene_dir):
      if not fn.endswith('.json'):
        continue
      with open(os.path.join(args.scene_dir, fn), 'r') as f:
        all_scenes.append({
          'filename': fn,
          'scene': json.load(f),
        })
    if args.num_scenes > 0:
      all_scenes = all_scenes[:args.num_scenes]
  elif args.scene_file != '':
    if args.use_local_copies == 0:
      scene_file = args.scene_file
    else:
      print('copying scene file to local directory...')
      shutil.copy(args.scene_file, '_scene_file.json')
      scene_file = '_scene_file.json'
      print('done')
    print('reading scene data from ', scene_file)
    with open(scene_file, 'r') as f:
      all_scenes = json.load(f)
    print('done')
    begin = args.scene_start_idx
    if args.num_scenes > 0:
      end = args.scene_start_idx + args.num_scenes
      all_scenes = all_scenes[begin:end]
    else:
      all_scenes = all_scenes[begin:]
  questions = []
  scene_count = 0
  for i, scene in enumerate(all_scenes):
    scene_fn = scene['filename']
    scene_struct = scene['scene']
    print('starting image %s (%d / %d)'
          % (scene_fn, i + 1, len(all_scenes)))
    print('number of objects: ', len(scene_struct['objects']))

    if scene_count % args.reset_counts_every == 0:
      print('resetting counts')
      template_counts, template_answer_counts = reset_counts()
    scene_count += 1

    # Order templates by the number of questions we have so far for those
    # templates. This is a simple heuristic to give a flat distribution over
    # templates.
    templates_items = list(templates.items())
    templates_items = sorted(templates_items,
                             key=lambda x: template_counts[x[0][:2]])
    num_instantiated = 0
    for (fn, idx, jdx), template in templates_items:
      # print('starting template ', fn, idx)
      print('trying template ', fn, idx)
      # verbose = fn == 'double_relate.json' and idx == 1
      verbose = False
      tic = time.time()
      ts, qs, ans = instantiate_templates_dfs(
          scene_struct,
          template,
          metadata,
          template_answer_counts[(fn, idx)],
          max_instances=args.instances_per_template,
          verbose=verbose)
      toc = time.time()
      print('that took ', toc - tic)
      for t, q, a in zip(ts, qs, ans):
        questions.append({
          'image': os.path.splitext(scene_fn)[0],
          'text_question': t,
          'structured_question': q,
          'answer': a,
          'template_filename': fn,
          'template_idx': idx,
          'template_expand_idx': jdx,
          'question_id': len(questions),
        })
      if len(ts) > 0:
        print('got one!')
        num_instantiated += 1
        template_counts[(fn, idx)] += 1
      else:
        print('did not get any =(')
      if num_instantiated >= args.templates_per_image:
        break

  with open(args.output_json, 'w') as f:
    json.dump(questions, f)
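
# Each record in the output JSON has the shape below (field values are
# illustrative, not actual output):
#   {
#     "image": "scene_000001",
#     "text_question": "How many large rubber cubes are there?",
#     "structured_question": [...],   # list of program nodes
#     "answer": 2,
#     "template_filename": "double_relate.json",
#     "template_idx": 0,
#     "template_expand_idx": 0,
#     "question_id": 0
#   }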

if __name__ == '__main__':
  if args.profile == 1:
    import cProfile
    cProfile.run('main(args)')
  else:
    main(args)