Skip to content

Instantly share code, notes, and snippets.

@Slater-Victoroff
Created January 9, 2020 14:34
Show Gist options
  • Save Slater-Victoroff/f89350b80ef80fc2c733b6167a8115cf to your computer and use it in GitHub Desktop.
Save Slater-Victoroff/f89350b80ef80fc2c733b6167a8115cf to your computer and use it in GitHub Desktop.
"""
Collection of methods built to assist in data augmentation for extraction datasets
"""
from ast import literal_eval
import json
import random
from collections import defaultdict
from functools import partial
from typing import Iterable, Dict, Callable
import pandas as pd
from BlueJet.sources.local import TrainingSet
def random_swap(text, options):
    """
    Return a randomly chosen member of ``options`` other than ``text``.

    ``options`` must be a set (set difference is used to exclude ``text``).
    When no other candidate remains, ``text`` itself is returned.
    """
    candidates = tuple(options - {text})
    if not candidates:
        # Nothing else to swap with, so the input is the only valid choice.
        candidates = (text,)
    return random.choice(candidates)
def _get_swap_lambdas(df, source_col, target_col, **kwargs):
options = defaultdict(set)
for _, row in df.iterrows():
for label in row[target_col]:
if not label.get("text"):
label["text"] = row[source_col][label["start"]:label["end"]]
options[label["label"]] |= {label["text"]}
# Yes, we are ignoring x on purpose
return {key: partial(random_swap, options=options[key]) for key in options}
def _get_noop(df, source_col, target_col, **kwargs):
options = set(label["label"] for labels in df[target_col] for label in labels)
return {key: lambda x : x for key in options}
class TokenAugmentor:
    """
    Only applies for extraction models.

    TokenAugmentor only replaces the tagged tokens and updates labels. Context
    augmentation will be handled in a separate Augmentor.

    ``strategy`` maps a strategy name (a key of ``strategies``, matched
    case-insensitively) to the number of augmented copies produced per
    document with that strategy.

    Example:
        df = pd.read_csv("<source>.csv")
        test = TokenAugmentor({"swap": 3, "original": 1})
        test.augment(df, "text", "question_#", "<destination>.csv")
    """

    # Strategy name -> factory(df, source_col, target_col, **kwargs) returning
    # a {label name: callable(text) -> replacement text} mapping.
    strategies = {
        "swap": _get_swap_lambdas,
        "original": _get_noop
    }

    def __init__(self, strategy: dict):
        """
        Validate and store the strategy counts.

        Keys are normalised to lower case so that the names accepted here are
        the same names later used to look up ``strategies``. Counts of keys
        that differ only by case are summed.

        Raises:
            ValueError: if a key does not name an available strategy.
        """
        normalized = {}
        for key, count in strategy.items():
            name = key.lower()
            if name not in self.strategies:
                raise ValueError("%s not an available strategy" % key)
            normalized[name] = normalized.get(name, 0) + count
        self.strategy = normalized

    def augment(self, df, source_col: str, target_col: str, results_file: str, **kwargs):
        """
        Augment every row of ``df`` and write the result to ``results_file``.

        ``df[target_col]`` holds, per row, a list of label dicts, either as a
        Python-literal string (parsed with ``literal_eval``) or already parsed.
        The output CSV has the columns ``source_col`` (augmented text) and
        ``target_col`` (JSON-encoded labels). Extra keyword arguments are
        forwarded to the strategy factories.
        """
        def parse_labels(value):
            # Cells read back from CSV are strings; parsed cells pass through.
            return literal_eval(value) if isinstance(value, str) else value

        # Work on a copy so the caller's column is not overwritten, which also
        # keeps repeated calls on the same frame working.
        df = df.copy()
        df[target_col] = df[target_col].apply(parse_labels)

        swappers = {
            key: self.strategies[key](df, source_col, target_col, **kwargs)
            for key in self.strategy
        }

        augmented_data = {source_col: [], target_col: []}
        for _, row in df.iterrows():
            new_sources, new_targets = self._augment_docs(
                row[source_col], row[target_col], swappers
            )
            augmented_data[source_col].extend(new_sources)
            augmented_data[target_col].extend(new_targets)

        # Close the handle deterministically; newline="" is what pandas asks
        # for when it is handed an open text file.
        with open(results_file, "w", newline="") as handle:
            pd.DataFrame(augmented_data).to_csv(handle)

    def _augment_docs(self, source: str, target: Iterable[dict],
                      swappers: Dict[str, Dict[str, Callable]]):
        """
        Produce all augmented copies of one document.

        For each configured strategy, ``_augment_doc`` is run as many times as
        that strategy's count. Attempts that raise ``NotImplementedError``
        (overlapping labels) are skipped.

        Returns:
            ``(new_sources, new_targets)``: parallel lists of augmented texts
            and their JSON-encoded label lists.
        """
        new_sources = []
        new_targets = []
        for key, value in self.strategy.items():
            for _ in range(value):
                try:
                    new_source, new_target = self._augment_doc(source, target, swappers[key])
                except NotImplementedError:
                    continue
                new_sources.append(new_source)
                new_targets.append(json.dumps(new_target))
        return new_sources, new_targets

    def _augment_doc(self, source: str, target: Iterable[dict], swapper: Dict[str, Callable]):
        """
        Replace each labelled span of ``source`` and recompute label offsets.

        Labels are processed in order of ``"start"``. ``swapper`` maps a label
        name to a callable that turns the current span text into its
        replacement.

        Returns:
            ``(new_source, new_target)``: the rewritten text and a list of
            label dicts (``start``, ``end``, ``label``, ``text``) indexing it.

        Raises:
            NotImplementedError: if two labels overlap.
            AssertionError: if a rewritten label does not index its own text.
        """
        offset = 0
        new_source = source
        new_target = []
        last_end = 0
        for entry in sorted(target, key=lambda x: x["start"]):
            # Fall back to the source slice when the label carries no text,
            # e.g. when only the "original" strategy is configured.
            original_value = entry.get("text") or source[entry["start"]:entry["end"]]
            new_value = swapper[entry["label"]](original_value)
            # Checking for overlap
            if new_target and (last_end > entry["start"]):
                raise NotImplementedError(
                    "Overlapping labels are not yet supported"
                )
            last_end = entry["end"]
            # Shift the original offsets by the net length change so far.
            new_start = entry["start"] + offset
            new_end = entry["end"] + offset
            new_source = new_source[:new_start] + new_value + new_source[new_end:]
            offset += (len(new_value) - len(original_value))
            final_end = entry["end"] + offset
            new_target.append({
                "start": new_start,
                "end": final_end,
                "label": entry["label"],
                "text": new_value
            })
        # Ensure that mapping was accomplished successfully. Else error.
        # Raised explicitly so the check still runs under ``python -O``.
        for new_entry in new_target:
            if new_source[new_entry["start"]: new_entry["end"]] != new_entry["text"]:
                raise AssertionError(
                    "Augmented label %r does not match the text at [%d:%d]"
                    % (new_entry["text"], new_entry["start"], new_entry["end"])
                )
        return new_source, new_target
def convert_format(source_file: str, source_col: str, target_col: str):
    """
    Convert files from standard Teach Export format to standard BlueJet format.

    Reads the CSV at ``source_file``. Each ``target_col`` cell is a JSON list
    of labels with ``label``, ``startOffset`` and ``endOffset`` keys; it is
    replaced by a list of dicts with ``label``, ``start``, ``end`` and
    ``text``, where ``text`` is the matching slice of the row's ``source_col``
    value. The file is then overwritten in place.
    """
    def rewrite_labels(row):
        text = row[source_col]
        return [
            {
                "label": old_label["label"],
                "start": old_label["startOffset"],
                "end": old_label["endOffset"],
                "text": text[old_label["startOffset"]: old_label["endOffset"]],
            }
            for old_label in json.loads(row[target_col])
        ]

    # Context managers close both handles; the original leaked them.
    with open(source_file) as handle:
        df = pd.read_csv(handle)

    df[target_col] = [rewrite_labels(row) for _, row in df.iterrows()]

    # newline="" is what pandas asks for when handed an open text file.
    with open(source_file, "w", newline="") as handle:
        df.to_csv(handle)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment