import json
from pathlib import Path

import pandas as pd
from syntaxgym import Suite
import transformers

# Materials from https://github.com/cpllab/syntactic-generalization
ds = []
for p in Path("../syntactic-generalization/test_suites/json").glob("*.json"):
    with open(p) as f:
        ds.append(json.load(f))
# Parse each raw dict into a syntaxgym Suite.
ss = [Suite.from_dict(d) for d in ds]
# Collect, for every item and condition, the content of each region referenced
# by the suite's first prediction (referenced_regions yields
# (condition_name, region_number) pairs).
out = []
for s in ss:
    for condition, region_number in s.predictions[0].referenced_regions:
        for item in s.items:
            c = next(cx for cx in item["conditions"] if cx["condition_name"] == condition)
            out.append((s.meta["name"], item["item_number"], region_number,
                        c["regions"][region_number - 1]["content"]))

df = pd.DataFrame(out, columns=["suite", "item", "region_number", "content"])
#########
# Tokenize each critical region string with the GPT-2 BPE tokenizer.
tk = transformers.AutoTokenizer.from_pretrained("gpt2")
decoded = [tk.convert_ids_to_tokens(tk.encode(string)) for string in df.content]
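# Illustrative check of what `decoded` holds ("Ġ" marks a token-initial space
# in the GPT-2 vocabulary), e.g.:
#   tk.convert_ids_to_tokens(tk.encode("hello world"))  # -> ['hello', 'Ġworld']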
##########
# We only care about items where critical region content differs by condition;
# matched_content is True when all conditions of an item share identical content.
df["matched_content"] = df.groupby(["suite", "item", "region_number"]).content.transform(lambda xs: len(set(xs)) == 1)
df["content_tokenized"] = ["_".join(content) for content in decoded]
df["num_tokens_bpe"] = [len(content) for content in decoded]
df["num_tokens_whitespace"] = df.content.str.count(" ") + 1
df["num_bpe_splits"] = df.num_tokens_bpe - df.num_tokens_whitespace
df.to_csv("critical_region_analysis.csv")
# avg. number of BPE splits per critical region for relevant items/suites, grouped by suite
df[~df.matched_content].groupby("suite").num_bpe_splits.mean()
# On the item level: how many items with unmatched critical region content
# contain at least one BPE split (sum = item count, mean = proportion of items).
(df[~df.matched_content].groupby(["suite", "item"]).num_bpe_splits.max() > 0).agg(["sum", "mean"])
##########
# Sample output (per-suite results):
#
# suite
# center_embed             1.000000
# center_embed_mod         1.000000
# number_orc               0.368421
# number_prep              0.368421
# number_src               0.368421
# reflexive_orc_fem        1.000000
# reflexive_orc_masc       1.000000
# reflexive_prep_fem       1.000000
# reflexive_prep_masc      1.000000
# reflexive_src_fem        1.000000
# reflexive_src_masc       1.000000
# subordination            0.021739
# subordination_orc-orc    0.021739
# subordination_pp-pp      0.021739
# subordination_src-src    0.021739
# Name: num_bpe_splits, dtype: float64