Skip to content

Instantly share code, notes, and snippets.

@danyaljj
Created November 14, 2020 01:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danyaljj/30ca8b9cf088ea0cc5273eccc6b88552 to your computer and use it in GitHub Desktop.
Save danyaljj/30ca8b9cf088ea0cc5273eccc6b88552 to your computer and use it in GitHub Desktop.
"""Assign train/dev/test splits to the rows of a tab-separated file.

Each row contributes one token list: the fourth- and third-from-last
tab-separated fields (called q1 and q2 below), each split on single spaces
and concatenated.  Every row first gets a random "dev" or "train" label.
Rows are then randomly picked as test "seeds"; each seed, plus every other
row whose token overlap with the seed reaches ``threshold``, is relabelled
"test".  The input rows are finally written back out with the split label
appended as an extra column.

NOTE(review): the indentation of the original script was lost, so the
nesting was reconstructed.  The overlap scan is assumed to run only for
rows picked as seeds, and the count summary is assumed to be printed once,
after the selection loop -- confirm against the author's copy.
"""
import random

# Minimum fraction of shared tokens (relative to either row's token count)
# for a row to be pulled into the test split alongside a seed row.
threshold = 0.50

INPUT_PATH = "/Users/danielk/ideaProjects/parsiglue-baselines/data/qqp/QQP-all.tsv"
OUTPUT_PATH = "/Users/danielk/ideaProjects/parsiglue-baselines/data/qqp/QQP-all-with-splits.tsv"

# Probability that a row is initially labelled "dev" instead of "train".
DEV_FRACTION = 0.33
# Probability that a row not already in "test" becomes a test seed.
TEST_SEED_FRACTION = 0.18


def _tokenize_pair(line):
    """Return the space-split tokens of q1 followed by those of q2.

    q1 and q2 are the fourth- and third-from-last tab-separated fields of
    ``line``.  Raises ``IndexError`` if the line has fewer than four fields.
    """
    fields = line.replace("\n", "").split("\t")
    q1 = fields[-4]
    q2 = fields[-3]
    return q1.split(" ") + q2.split(" ")


def assign_splits(all_sentence_pairs, overlap_threshold=threshold, *,
                  rng=random, verbose=True):
    """Return one split label ("train", "dev" or "test") per token list.

    Args:
        all_sentence_pairs: list of non-empty token lists, one per row.
        overlap_threshold: a row joins the test split when the number of
            distinct tokens it shares with a seed, divided by either row's
            token count, is at least this value.
        rng: object exposing ``random()``; defaults to the ``random`` module.
            Pass ``random.Random(seed)`` for reproducible splits.
        verbose: when true, print each seed and the rows it pulls in.

    Returns:
        A list of labels parallel to ``all_sentence_pairs``.
    """
    # One draw per row, in row order, for the initial dev/train label.
    splits = [
        "dev" if rng.random() < DEV_FRACTION else "train"
        for _ in all_sentence_pairs
    ]

    # Build each row's token set once instead of once per comparison.
    token_sets = [set(tokens) for tokens in all_sentence_pairs]

    for i, sentence in enumerate(all_sentence_pairs):
        # The draw comes first so exactly one random number is consumed per
        # row, whether or not the row is already in the test split.
        if rng.random() >= TEST_SEED_FRACTION or splits[i] == "test":
            continue

        splits[i] = "test"
        if verbose:
            print(f"----------\n {' '.join(sentence)}")

        seed_tokens = token_sets[i]
        seed_length = len(sentence)
        for i2, sentence2 in enumerate(all_sentence_pairs):
            if i == i2:
                continue
            intersection = len(seed_tokens & token_sets[i2])
            # Denominators are token counts (duplicates included), not the
            # sizes of the de-duplicated sets.
            o1 = intersection / len(sentence2)
            o2 = intersection / seed_length
            if o1 >= overlap_threshold or o2 >= overlap_threshold:
                if verbose:
                    print(f" -> {' '.join(sentence2)} ({o1}, {o2})")
                # Keep near-duplicates of a test row out of train/dev.
                splits[i2] = "test"

    return splits


def main(input_path=INPUT_PATH, output_path=OUTPUT_PATH, rng=random):
    """Read ``input_path``, assign splits, and write rows plus split label.

    Each output row is the input row with its newline removed, followed by
    a tab, the split label, and a newline.
    """
    with open(input_path, encoding="utf-8") as f:
        all_lines = f.readlines()

    all_sentence_pairs = [_tokenize_pair(line) for line in all_lines]
    all_splits = assign_splits(all_sentence_pairs, rng=rng)

    test_count = all_splits.count("test")
    train_count = all_splits.count("train")
    dev_count = all_splits.count("dev")
    print(f" -> test count: {test_count}")
    print(f" -> train count: {train_count}")
    print(f" -> dev count: {dev_count}")

    # The context manager guarantees the output is flushed and closed.
    with open(output_path, "w", encoding="utf-8") as outfile:
        outfile.writelines(
            line.replace("\n", "") + "\t" + split + "\n"
            for line, split in zip(all_lines, all_splits)
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment