Skip to content

Instantly share code, notes, and snippets.

@jayelm
Created August 2, 2019 02:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jayelm/8918f0bb6580dcc9af81f2f842784dd5 to your computer and use it in GitHub Desktop.
Save jayelm/8918f0bb6580dcc9af81f2f842784dd5 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# Use this commit:
# 82b3c3a106a5d5f6d8afe98c34b301e3ed696865
# https://github.com/AlexKuhnle/ShapeWorld/tree/82b3c3a106a5d5f6d8afe98c34b301e3ed696865
from collections import defaultdict
import numpy as np
import os
from shapeworld import dataset
import json
import itertools
from tqdm import tqdm, trange
from shapeworld.world.shape import Shape
from shapeworld.world.color import Color
# Sampling ranges copied from ShapeWorld's generators/generator.py (at the
# commit pinned above). negate_world passes them to Shape.random_instance /
# Color.random_instance when drawing a replacement shape or color.
SHADE_RANGE = 0.33
SIZE_RANGE = (0.1, 0.2)
DISTORTION_RANGE = (2.0, 3.0)
from PIL import Image
import os
def save_world(world, filename):
    """
    Render `world` to an 8-bit image file.

    world.get_array() is assumed to return a float array with values in
    [0, 1] (TODO confirm against ShapeWorld). Values are scaled to 0-255,
    rounded, and clipped so that out-of-range values saturate instead of
    wrapping around when cast to uint8. PIL picks the image format from the
    filename extension.
    """
    worldarr = world.get_array()
    # np.round, not the np.round_ alias, which was removed in NumPy 2.0.
    worldarr = np.clip(np.round(worldarr * 255), 0, 255).astype(np.uint8)
    img = Image.fromarray(worldarr)
    img.save(filename)
def negate_world(world, save=False):
    """
    Perturb, in place, every held-out entity of `world`.

    An entity is held out when str(entity) is listed in IGNORE_ENTITIES. For
    each such entity, one draw from the module RNG `random` decides, with
    equal odds, whether its shape is replaced by a random shape with a
    different name or its color by a random color with a different name.

    Args:
        world: ShapeWorld world to modify.
        save: if True, write test_before.png / test_after.png snapshots.

    Raises:
        RuntimeError: if no entity of the world was held out (nothing changed).
    """
    if save:
        save_world(world, 'test_before.png')
    entity_names = [str(e) for e in world.entities]
    # Pick the targets from the names computed before any entity is touched.
    targets = [
        entity
        for entity, label in zip(world.entities, entity_names)
        if label in IGNORE_ENTITIES
    ]
    for entity in targets:
        if random.randint(2) == 0:
            # Keep the shape; resample until the color name differs.
            while True:
                candidate_color = Color.random_instance(Color.colors, SHADE_RANGE)
                if candidate_color.name != entity.color.name:
                    break
            entity.color = candidate_color
        else:
            # Keep the color; resample until the shape name differs.
            while True:
                candidate_shape = Shape.random_instance(Shape.shapes, SIZE_RANGE, DISTORTION_RANGE)
                if candidate_shape.name != entity.shape.name:
                    break
            entity.shape = candidate_shape
    if save:
        save_world(world, 'test_after.png')
    if not targets:
        raise RuntimeError("I didn't modify this world: {}".format(entity_names))
# Number of distinct realized captions to collect for each split.
N_TRAIN, N_VAL, N_TEST = 2000, 500, 300
# Smaller sizes for quick debugging runs:
# N_TRAIN = 50
# N_VAL = 25
# N_TEST = 10
N_CAPTIONS = sum((N_TRAIN, N_VAL, N_TEST))
# Per-world array geometry used when allocating the output arrays in generate.
WIDTH = HEIGHT = 64
CHANNELS = 3
# Number of support worlds of the same caption stored with every example.
EXAMPLES = 4
# ShapeWorld "agreement" dataset (spatial_jda config). Its world_generator,
# world_captioner and caption_realizer are used throughout this script.
DATASET = dataset(dtype="agreement", name="spatial_jda")
# Seeded NumPy RNG for this script's own choices (which attribute
# negate_world changes, caption shuffling, example sampling). Note that it
# reuses the name of the stdlib `random` module. ShapeWorld's internal
# sampling presumably uses its own RNG, so this seed alone may not make runs
# reproducible; TODO confirm.
random = np.random.RandomState(0)
# Held-out (shape, color, texture) triples. Captions that mention one of these
# color/shape pairs go to the test split; train/val worlds containing one of
# these entities are discarded.
IGNORE_ENTITIES_RAW = [
    ('square', 'red', 'solid'),
    ('rectangle', 'green', 'solid'),
    ('triangle', 'blue', 'solid'),
    ('pentagon', 'yellow', 'solid'),
    ('cross', 'magenta', 'solid'),
    ('circle', 'cyan', 'solid'),
    ('semicircle', 'white', 'solid'),
]
# "color shape" phrases searched for in realized caption text, e.g. "red square".
IGNORE_CAPTION_NAMES = [
    '{} {}'.format(color, shape) for shape, color, _texture in IGNORE_ENTITIES_RAW
]
# Entity strings compared against str(entity), e.g. "(square, red, solid)".
# Assumes str() of a ShapeWorld entity renders in this form; TODO confirm.
IGNORE_ENTITIES = [str(triple).replace("\'", '') for triple in IGNORE_ENTITIES_RAW]
caption_data = {}
test_caption_data = {}
trainval_pbar = tqdm(total=N_TRAIN + N_VAL, desc='Train/val captions')
test_pbar = tqdm(total=N_TEST, desc='Test captions')
# Sample captions until both pools reach their targets. A caption whose
# realized text mentions a held-out "color shape" phrase (IGNORE_CAPTION_NAMES)
# goes to the test pool; every other caption goes to the train/val pool. Both
# pools keep growing until *both* targets are met, so either can overshoot.
trainval_marked_extra = False
test_marked_extra = False
while (len(caption_data) < (N_TRAIN + N_VAL)) or (len(test_caption_data) < N_TEST):
    # Label a pool's progress bar once its target is reached (>=, not >):
    # anything it collects afterwards is surplus.
    if not trainval_marked_extra and len(caption_data) >= (N_TRAIN + N_VAL):
        trainval_pbar.set_description('Train/val captions (extra)')
        trainval_marked_extra = True
    if not test_marked_extra and len(test_caption_data) >= N_TEST:
        test_pbar.set_description('Test captions (extra)')
        test_marked_extra = True
    DATASET.world_generator.sample_values(mode="train")
    DATASET.world_captioner.sample_values(mode="train", correct=True)
    # Retry (without re-sampling values) until both a world and a caption for
    # it have been produced.
    while True:
        world = DATASET.world_generator()
        if world is None:
            continue
        caption = DATASET.world_captioner(entities=world.entities)
        if caption is None:
            continue
        break
    realized, = DATASET.caption_realizer.realize(captions=[caption])
    # Materialize the token sequence once: it serves as the hashable pool key
    # and as the source of the joined text. Joining first and then calling
    # tuple() would leave an empty key if realize() returned an iterator.
    realized = tuple(realized)
    realized_str = ' '.join(realized)
    if any(name in realized_str for name in IGNORE_CAPTION_NAMES):
        pool, pool_pbar = test_caption_data, test_pbar
    else:
        pool, pool_pbar = caption_data, trainval_pbar
    if realized not in pool:
        pool_pbar.update(1)
        pool[realized] = caption
trainval_pbar.close()
test_pbar.close()
# Compositional split: hold out specific color/shape pairs (IGNORE_ENTITIES_RAW,
# e.g. "red square") at train/val time; the test split uses only captions that
# mention them. Make sure `generate` makes sense, and inspect the results!
# How to swap shapes
# import ipdb; ipdb.set_trace()
# print([str(x) for x in world.entities])
# save_world(world, 'test.png')
# # Modify entities - modify shape of red triangle
# center0 = world.entities[0].center
# world.entities[0] = world.entities[1].copy()
# world.entities[0].set_center(center0)
# save_world(world, 'test2.png')
# PROBLEMS:
# (1) We want to test systematic generalization which means we need "tricky"
# negative examples, e.g. examples where the red triangle is a red shape but
# not a triangle, or a triangle but not a red shape. Can we modify the world to
# make that happen?
# (2) There is no way to verify that test captions are truly "new", i.e. that
# they haven't appeared in some hidden form in the train set (specifically,
# consider above vs. below / left vs. right; I think that's the only issue).
# This is not a major issue for the compositional split, since we guarantee
# that the held-out entities never appear in train/val worlds.
def _shuffled_keys(mapping):
    """Return the keys of `mapping`, sorted and then shuffled with the module RNG.

    Sorting first makes the result depend only on the key set and the RNG
    state, not on dict insertion order.
    """
    keys = sorted(mapping)
    random.shuffle(keys)
    return keys


# Split the train/val caption pool. Surplus captions beyond N_TRAIN + N_VAL
# (collected while waiting for the test pool to fill) are dropped.
captions = _shuffled_keys(caption_data)
train_captions = captions[:N_TRAIN]
val_captions = captions[N_TRAIN:N_TRAIN + N_VAL]
# Cap the test pool at N_TEST too. It can overshoot while the train/val pool is
# still filling, and the train/val splits are capped the same way above.
test_captions = _shuffled_keys(test_caption_data)[:N_TEST]
# Combine, so that generate() can look up any caption key in caption_data.
caption_data.update(test_caption_data)
def has_ignore_entity(world, ignore_entities):
    """
    Tell whether `world` contains at least one held-out entity.

    Args:
        world: object exposing an iterable `entities` attribute.
        ignore_entities: entity strings to look for; each is compared with
            str() of every entity in the world.

    Returns:
        True if some entry of ignore_entities equals str(entity) for an
        entity of the world, otherwise False.
    """
    names_in_world = list(map(str, world.entities))
    for candidate in ignore_entities:
        if candidate in names_in_world:
            return True
    return False
def _caption_dirname(key):
    """Return the vis/ directory that holds preview images for caption `key`."""
    # [:-2] drops the last two characters of the joined caption, presumably
    # the "_." left by a trailing "." token; TODO confirm.
    return os.path.join('vis', '_'.join(key)[:-2])


def _collect_worlds(name, captions, max_scenes, ignore_entities):
    """Sample max_scenes usable worlds; map each caption key to the worlds that agree with it."""
    mappings = defaultdict(list)
    total_scenes = 0
    pbar = tqdm(total=max_scenes, desc='{} scenes'.format(name))
    while total_scenes < max_scenes:
        DATASET.world_generator.sample_values(mode="train")
        world = DATASET.world_generator()
        if world is None:
            continue
        if ignore_entities is not None and has_ignore_entity(world, ignore_entities):
            # Discard the world, since it contains a forbidden (held-out) entity.
            continue
        # Add this world to every caption it agrees with. The same world object
        # can therefore sit in several captions' lists within one split, but
        # each split samples its own worlds.
        for key in captions:
            if caption_data[key].agreement(entities=world.entities) > 0:
                mappings[key].append(world)
        total_scenes += 1
        pbar.update(1)
    pbar.close()
    return mappings


def _save_caption_previews(mappings, n_preview=5):
    """Save the first n_preview worlds of every caption under vis/<caption>/."""
    for key, worlds in mappings.items():
        key_dirname = _caption_dirname(key)
        os.makedirs(key_dirname, exist_ok=True)
        for i, world in enumerate(worlds[:n_preview]):
            save_world(world, os.path.join(key_dirname, '{}.png'.format(i)))


def _out_of_worlds(name, i_example, n_examples):
    """Build the error raised when the sampled worlds cannot fill the split."""
    return RuntimeError(
        "Ran out of worlds for split '{}' after {} of {} examples; "
        "try a larger examples_ratio".format(name, i_example, n_examples))


def _build_examples(name, captions, mappings, n_examples, hard_negatives,
                    save_hard_negatives):
    """Assemble n_examples examples, consuming (popping) worlds from `mappings`.

    Returns (examples, inputs, labels, hints, test_hints).
    """
    import copy  # hard negatives are built on a deep copy, see below

    examples = np.zeros((n_examples, EXAMPLES, WIDTH, HEIGHT, CHANNELS))
    inputs = np.zeros((n_examples, WIDTH, HEIGHT, CHANNELS))
    labels = np.zeros((n_examples,), dtype=np.uint8)
    hints = []
    test_hints = []
    i_example = 0
    pbar = tqdm(total=n_examples, desc='{} examples'.format(name))
    while i_example < n_examples:
        key = captions[random.randint(len(captions))]
        worlds = mappings[key]
        if len(worlds) < EXAMPLES + 1:
            # This caption can no longer supply a support set plus a query.
            # Fail instead of looping forever once no caption can.
            if not any(len(mappings[k]) >= EXAMPLES + 1 for k in captions):
                raise _out_of_worlds(name, i_example, n_examples)
            continue
        vis_dir = None
        if save_hard_negatives is not None and i_example < save_hard_negatives:
            vis_dir = _caption_dirname(key)
            os.makedirs(vis_dir, exist_ok=True)
        # Support set: EXAMPLES worlds that agree with the caption.
        for i_world in range(EXAMPLES):
            world = worlds.pop()
            if vis_dir is not None:
                save_world(world, os.path.join(vis_dir, 'train_{}.png'.format(i_world)))
            examples[i_example, i_world, ...] = world.get_array()
        if random.randint(2) == 0:
            # Positive query: one more world of this caption.
            world = worlds.pop()
            inputs[i_example, ...] = world.get_array()
            labels[i_example] = 1
            test_hint = key
            if vis_dir is not None:
                save_world(world, os.path.join(vis_dir, 'test_pos.png'))
        elif hard_negatives:
            # Hard negative: a world of this caption whose held-out entities
            # are perturbed by negate_world (change color or shape), to test
            # compositionality. Negate a deep copy, because the same world
            # object may also sit in other captions' lists and must stay
            # intact there.
            world = copy.deepcopy(worlds.pop())
            negate_world(world)
            inputs[i_example, ...] = world.get_array()
            labels[i_example] = 0
            # The negated world's exact description is unknown, so record the
            # hint as the negation of the original caption.
            # NOTE(review): agreement with the caption is not re-checked here.
            test_hint = ('not', ) + key
            if vis_dir is not None:
                save_world(world, os.path.join(vis_dir, 'test_neg.png'))
        else:
            # Easy negative: a world taken from a *different* caption. This
            # still does NOT guarantee the world disagrees with this caption.
            if not any(mappings[k] for k in captions if k != key):
                raise _out_of_worlds(name, i_example, n_examples)
            while True:
                other_key = captions[random.randint(len(captions))]
                # Skip the caption itself: its worlds agree with it by
                # construction and would be false negatives.
                if other_key != key and mappings[other_key]:
                    other_world = mappings[other_key].pop()
                    break
            inputs[i_example, ...] = other_world.get_array()
            labels[i_example] = 0
            test_hint = other_key
        hints.append(" ".join(key))
        test_hints.append(" ".join(test_hint))
        i_example += 1
        pbar.update(1)
    pbar.close()
    return examples, inputs, labels, hints, test_hints


def _write_split(name, examples, inputs, labels, hints, test_hints):
    """Write one split's arrays (.npy) and hint lists (.json) into directory `name`."""
    os.makedirs(name, exist_ok=True)
    np.save(os.path.join(name, "examples.npy"), examples)
    np.save(os.path.join(name, "inputs.npy"), inputs)
    np.save(os.path.join(name, "labels.npy"), labels)
    with open(os.path.join(name, "hints.json"), "w") as hint_f:
        json.dump(hints, hint_f)
    with open(os.path.join(name, "test_hints.json"), "w") as t_hint_f:
        json.dump(test_hints, t_hint_f)


def generate(name, captions, n_examples, ignore_entities=None, hard_negatives=False,
             save=False, save_hard_negatives=None, examples_ratio=20):
    """
    Build one split of the dataset and save it to directory `name`.

    Each example stores EXAMPLES support worlds that agree with a caption, one
    query world, and a label (1 if the query was drawn from the same caption,
    0 for a negative).

    Args:
        name: split name; also the output directory.
        captions: realized-caption keys (token tuples) into caption_data.
        n_examples: number of examples to produce.
        ignore_entities: if given, sampled worlds containing any of these
            entity strings are discarded (see has_ignore_entity).
        hard_negatives: if True, negatives are worlds of the same caption
            perturbed by negate_world rather than worlds of another caption.
            Only allowed for the 'test' split.
        save: if True, save up to 5 worlds per caption under vis/.
        save_hard_negatives: if not None, save the images of the first this
            many examples under vis/.
        examples_ratio: worlds sampled per requested example.

    Writes examples.npy, inputs.npy, labels.npy, hints.json and
    test_hints.json into `name`.

    Raises:
        ValueError: hard_negatives requested for a split other than 'test',
            or no captions were given while examples were requested.
        RuntimeError: the sampled worlds ran out before n_examples examples
            were built (increase examples_ratio).
    """
    # Validate up front, before the expensive world sampling.
    if hard_negatives and name != 'test':
        raise ValueError("hard_negatives is only supported for the 'test' split")
    if n_examples > 0 and not captions:
        raise ValueError("no captions given for split '{}'".format(name))
    # Brute force: sample many scenes and keep each one under every caption it
    # agrees with, so that enough positive examples exist per caption.
    mappings = _collect_worlds(name, captions, n_examples * examples_ratio, ignore_entities)
    for key in mappings:
        print(" ".join(key), len(mappings[key]))
    if save:
        _save_caption_previews(mappings)
    split_arrays = _build_examples(name, captions, mappings, n_examples,
                                   hard_negatives, save_hard_negatives)
    print("\n\n")
    _write_split(name, *split_arrays)
# Build the splits. Train/val discard any world containing a held-out entity
# (ignore_entities). The test split uses the held-out captions with hard
# negatives and saves the images of its first 100 examples under vis/.
generate("train", train_captions, 9000, ignore_entities=IGNORE_ENTITIES)
generate("val", val_captions, 1000, ignore_entities=IGNORE_ENTITIES)
generate("test", test_captions, 1000, hard_negatives=True, save_hard_negatives=100)
# Disabled: extra splits built from the training captions.
# generate("val_same", train_captions, 500)
# generate("test_same", train_captions, 500)
# Small sizes for quick debugging runs:
# generate("train", train_captions, 50, ignore_entities=IGNORE_ENTITIES)
# generate("val", val_captions, 50, ignore_entities=IGNORE_ENTITIES)
# generate("test", test_captions, 10, hard_negatives=True, save_hard_negatives=5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment