-
-
Save mkshing/d6371cbfdd50d4f352cee247fd4dd86a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
from collections import Counter | |
import re | |
import emoji | |
import neologdn | |
from fugashi import Tagger | |
from nltk.util import ngrams | |
from datasets import load_dataset | |
from tqdm import tqdm | |
tagger = Tagger('-Owakati') | |
def to_tuple(lst):
    """Convert every inner sequence of *lst* to a tuple (making items hashable)."""
    return list(map(tuple, lst))
# Compiled once at import time (the original rebuilt this regex on every call).
# Fallback ranges for pictographs that emoji.is_emoji may not cover.
_EMOJI_PATTERN = re.compile(
    "["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U00002702-\U000027B0"
    "]+",
    flags=re.UNICODE,
)


def normalize_answer(s):
    """Strip emoji from *s*, apply neologdn Japanese normalization, and
    collapse runs of whitespace to single spaces.

    Note: despite the name, this does NOT lowercase or remove punctuation;
    it only removes emoji and normalizes text/whitespace.
    """

    def white_space_fix(text):
        # Collapse all whitespace runs (including newlines) to single spaces.
        return " ".join(text.split())

    def remove_emoji(text):
        # First pass: drop characters the emoji package recognizes.
        text = "".join("" if emoji.is_emoji(c) else c for c in text)
        # Second pass: regex fallback for remaining pictographic ranges.
        return _EMOJI_PATTERN.sub("", text)

    return white_space_fix(neologdn.normalize(remove_emoji(s)))
def mecab_tokenizer(text):
    """Normalize *text* and split it into surface tokens with the global MeCab tagger."""
    normalized = normalize_answer(text)
    return tagger.parse(normalized).split()
def to_15gram(text):
    """Tokenize *text* and return all contiguous 15-token windows as a list of tuples."""
    tokens = mecab_tokenizer(text)
    return list(ngrams(tokens, 15))
def tokenize_function(examples):
    """datasets.Dataset.map callback: replace each text field with its 15-gram list.

    The "title" field is intentionally left untouched.
    """
    return {field: to_15gram(examples[field]) for field in ("summary", "text")}
def main(split="test"):
    """Deduplicate rows of the Japanese XL-Sum *split* by 15-gram overlap.

    Any two rows that share at least one 15-gram in the same field
    ("summary" or "text") are treated as duplicates; only one
    representative row per duplicate group is kept.  Surviving rows are
    written to "<split>.csv".
    """
    print(f"Processing split={split}")
    dataset = load_dataset("csebuetnlp/xlsum", "japanese", split=split)
    # NOTE(review): map() round-trips the 15-gram tuples through Arrow, which
    # appears to return each row[key] as a list of *lists* — that is why
    # to_tuple() is needed before Counter, and why the membership test below
    # is `list(match) in text`. Confirm against the installed datasets version.
    tokenized_dataset = dataset.map(tokenize_function)
    all_ids = []     # every row index that appears in any duplicate group
    add_ids = []     # one kept (representative) row index per group
    delete_ids = []  # row indices to drop as duplicates
    for key in ["summary", "text"]:
        texts = [row[key] for row in tokenized_dataset]
        # Flatten all rows' 15-grams into one list for frequency counting.
        all_texts = list(itertools.chain(*texts))
        res = Counter(to_tuple(all_texts))
        # 15-grams occurring in more than one place => duplicate evidence.
        morethan1 = {k: v for k, v in res.items() if v > 1}
        # just take first example
        for match in tqdm(morethan1.keys()):
            # Row indices whose field contains this repeated 15-gram.
            # O(rows) scan per 15-gram — slow but simple.
            found = [idx for idx, text in enumerate(texts) if list(match) in text]
            # Reuse an existing representative if any matched row already has one,
            # so overlapping groups are merged onto one kept row.
            id_ = [i for i in found if i in add_ids]
            if len(id_) == 0:
                id_ = found[0]
                add_ids.append(id_)
            else:
                id_ = id_[0]
            all_ids += [i for i in found if i not in all_ids]
            delete_ids += [i for i in found if id_ != i and i not in delete_ids]
    # If a row was chosen as representative in one group but marked for deletion
    # in another, deletion wins: remove it from add_ids.
    match = list(set(add_ids) & set(delete_ids))
    add_ids = [i for i in add_ids if i not in match]  # exclude ids
    # Invariants: every duplicate-group member is either kept or deleted,
    # and no index is both.
    assert len(all_ids) == len(add_ids)+len(delete_ids)
    assert len(set(add_ids) & set(delete_ids)) == 0
    df = dataset.to_pandas()
    # drop() by label works here because to_pandas() yields a default
    # RangeIndex, so positional indices equal index labels.
    df_res = df.drop(delete_ids)
    print(f"Deleted {len(delete_ids)}\nBefore: {len(df)} --> After: {len(df_res)}")
    df_res.to_csv(f"{split}.csv", index=False)
if __name__ == "__main__":
    # The "train" split is deliberately skipped; add it here to process it too.
    for split in ("validation", "test"):
        main(split)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment