Skip to content

Instantly share code, notes, and snippets.

@dchaplinsky
Last active April 12, 2023 16:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dchaplinsky/3e0e3673d568a28ef38fb0848b90ebb5 to your computer and use it in GitHub Desktop.
Save dchaplinsky/3e0e3673d568a28ef38fb0848b90ebb5 to your computer and use it in GitHub Desktop.
A set of scripts to export and deduplicate data from various Ukrainian corpora for GPT-x tuning
import json
import argparse
from typing import Dict
from pathlib import Path
import smart_open
import ftfy
from tqdm import tqdm
import html2text
from datasets import load_dataset
# Number of documents one HTML2Text instance converts before remove_tags swaps
# it for a fresh one. Presumably this bounds state the converter accumulates
# across calls -- NOTE(review): confirm the recycling is still needed.
CONVERTER_RECYCLE_AFTER = 1000


def _new_converter() -> html2text.HTML2Text:
    """Build an HTML2Text converter that drops links and images, with a zeroed use counter."""
    converter = html2text.HTML2Text()
    converter.ignore_links = True
    converter.ignore_images = True
    # Ad-hoc attribute: how many documents this instance has converted so far.
    converter.used = 0
    return converter


h = _new_converter()


def remove_tags(s: str) -> str:
    """
    Turn HTML into Markdown-formatted text, with links and images dropped.

    :param s: HTML source
    :return: the Markdown rendering with surrounding whitespace stripped

    Reuses the module-level converter ``h`` and replaces it with a freshly
    configured one once it has served more than CONVERTER_RECYCLE_AFTER calls.
    """
    global h
    if h.used > CONVERTER_RECYCLE_AFTER:
        h = _new_converter()
    else:
        h.used += 1
    return h.handle(s).strip()
def process_doc(doc: Dict) -> Dict:
    """
    Render a news-dataset record into the JSONL record format requested by Volodymyr.

    :param doc: doc dict from the dataset
    :return: dict with keys ``_id``, ``text`` (HTML converted to Markdown),
        ``title``, ``date_of_publish`` (the record's ``datetime``) and ``tags``
        (a one-element list holding the ``owner`` field). Text fields are
        repaired with ftfy; missing or null text fields become "".
    """
    doc_id = doc.get("id")
    return {
        # Keep a missing id as null instead of the literal string "None";
        # otherwise every id-less document would share the same id downstream.
        "_id": None if doc_id is None else str(doc_id),
        "text": ftfy.fix_text(remove_tags(doc.get("text", "") or "")),
        "title": ftfy.fix_text(doc.get("title", "") or ""),
        "date_of_publish": doc.get("datetime", ""),
        "tags": [ftfy.fix_text(doc.get("owner", "") or "")],
    }
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Export news dataset in the format requested by Volodymyr"
    )
    parser.add_argument("output_file", help="path to output JSONL file", type=Path)
    args = parser.parse_args()

    dataset = load_dataset("zeusfsx/ukrainian-news", split="train", streaming=True)

    # The hard-coded total only feeds the tqdm progress bar (presumably because
    # the streamed dataset cannot report its own length).
    with smart_open.open(args.output_file, "wt", encoding="utf-8") as writer:
        for doc in tqdm(dataset, total=10_569_428):
            writer.write(
                json.dumps(process_doc(doc), ensure_ascii=False, sort_keys=True) + "\n"
            )
from typing import List, Set, Tuple, Iterator, TypeVar, Dict, Optional
from pathlib import Path
from glob import glob
from collections import namedtuple
import pickle
import argparse
import json
import multiprocessing
from hashlib import sha256
from functools import partial
from itertools import islice
import smart_open
from tqdm import tqdm
import sentencepiece as spm
from datasketch import MinHash, MinHashLSH
# Generic item type for batch_iterator.
T = TypeVar("T")

# Key of one LSH configuration: similarity threshold, number of MinHash
# permutations and shingle length (in tokens).
LSHParam = namedtuple("LSHParam", "threshold num_perm shingle_length")

# SentencePiece processor of the current process; worker_init fills it in.
sp_model: Optional[spm.SentencePieceProcessor] = None
def _handle_xz(file_obj, mode):
    """
    smart_open compression hook: wrap *file_obj* in an XZ (de)compressor.

    file_obj: the underlying binary file object smart_open hands over
    mode: the open mode to pass to LZMAFile
    returns an ``lzma.LZMAFile`` reading/writing the XZ container format
    """
    # Imported locally so the hook is self-contained; the original code used
    # `lzma` without importing it, so every ".xz" input failed with NameError.
    import lzma

    return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)


# Let smart_open transparently (de)compress files whose names end in ".xz".
smart_open.register_compressor(".xz", _handle_xz)
def batch_iterator(iterator: Iterator[T], batch_size: int = 50) -> Iterator[List[T]]:
    """
    Lazily split an iterable into consecutive lists of at most ``batch_size`` items.

    iterator: any iterable to consume
    batch_size: maximum number of items per batch
    returns an iterator of lists; only the final list may be shorter, and an
    empty list is never produced
    """
    source = iter(iterator)
    current = list(islice(source, batch_size))
    while current:
        yield current
        current = list(islice(source, batch_size))
def tokenize_text(text: str, sp_model) -> List[int]:
    """
    Convert *text* into SentencePiece token ids.

    text: the string to encode
    sp_model: a loaded SentencePiece processor (anything exposing ``encode``)
    returns whatever ``sp_model.encode(text)`` produces

    Example (requires the real model, hence skipped):

    >>> tokenize_text("привіт, як справи?", sp_model)  # doctest: +SKIP
    [395, 627, 50096, 524, 5833, 50219]
    """
    token_ids = sp_model.encode(text)
    return token_ids
def get_shingles(tokens: List[int], shingle_length: int) -> Set[Tuple[int, ...]]:
    """
    Compute the set of contiguous token n-grams ("shingles") of a given length.

    tokens: the list of tokens
    shingle_length: the number of tokens per shingle; must be >= 1
    returns a set of shingles. A non-empty token list shorter than
    ``shingle_length`` yields one shingle holding all of its tokens. Without
    this, every short document would map to the same empty set and look like a
    duplicate of every other short document.
    raises ValueError if ``shingle_length`` < 1

    >>> sorted(get_shingles([1, 2, 3, 4, 5], 2))
    [(1, 2), (2, 3), (3, 4), (4, 5)]
    >>> get_shingles([1, 2], 3)
    {(1, 2)}
    >>> get_shingles([], 3)
    set()
    >>> get_shingles(tokenize_text("привіт, як справи?", sp_model), 3)  # doctest: +SKIP
    {(395, 627, 50096), (524, 5833, 50219), (627, 50096, 524), (50096, 524, 5833)}
    """
    if shingle_length < 1:
        raise ValueError(f"shingle_length must be >= 1, got {shingle_length}")
    if 0 < len(tokens) < shingle_length:
        return {tuple(tokens)}
    return {
        tuple(tokens[i : i + shingle_length])
        for i in range(len(tokens) - shingle_length + 1)
    }
def create_minhash(shingles: Set[Tuple[int, ...]], num_perm: int) -> MinHash:
    """
    Build a MinHash signature over a set of shingles.

    shingles: the shingles to hash; each is fed in as the UTF-8 bytes of its
        ``str()`` representation
    num_perm: number of permutation functions in the signature
    returns the populated MinHash
    """
    signature = MinHash(num_perm=num_perm)
    encoded_shingles = (str(shingle).encode("utf8") for shingle in shingles)
    for payload in encoded_shingles:
        signature.update(payload)
    return signature
def process_records(
    record_str: str, shingle_length: int, num_perm: int
) -> Dict[str, object]:
    """
    Parse one JSONL line and compute the MinHash signature of its ``text`` field.

    Intended to run inside a pool worker: it tokenizes with the module-level
    ``sp_model`` that ``worker_init`` loads.

    record_str: one JSON object serialized as a line of text
    shingle_length: number of tokens per shingle
    num_perm: number of MinHash permutations
    returns ``{"id": str, "minhash": MinHash, "tokens": int}``. The id is the
    value of the first key present among ``id``, ``_id`` and ``id_``. When none
    of these keys is present, or the chosen value is null, the id falls back to
    the SHA-256 hex digest of the raw line. ``tokens`` is the token count of the
    text.
    raises KeyError if the record has no ``text`` field
    """
    record = json.loads(record_str)
    text = record["text"]
    id_ = record.get("id", record.get("_id", record.get("id_", None)))
    if id_ is None:
        id_ = sha256(record_str.encode("utf-8")).hexdigest()
    tokens = tokenize_text(text, sp_model)
    minhash = create_minhash(get_shingles(tokens, shingle_length), num_perm)
    # Normalize to str: ids in the JSON may be numbers, while main() compares
    # ids with min() (mixed int/str would raise) and writes them via `id + "\n"`.
    return {"id": str(id_), "minhash": minhash, "tokens": len(tokens)}
def worker_init(
    sp_model_name: str,
) -> None:
    """
    Pool initializer: load the SentencePiece model once per worker process.

    sp_model_name: path to the SentencePiece model file
    The loaded processor is stored in the module-level ``sp_model`` so that
    process_records can use it.
    """
    global sp_model
    loaded = spm.SentencePieceProcessor(model_file=sp_model_name)
    sp_model = loaded
def _build_indexes(cli_args: argparse.Namespace) -> Dict[LSHParam, MinHashLSH]:
    """Create one empty MinHashLSH index per requested threshold."""
    return {
        LSHParam(
            threshold=threshold,
            num_perm=cli_args.num_perm,
            shingle_length=cli_args.shingle_length,
        ): MinHashLSH(threshold=threshold, num_perm=cli_args.num_perm)
        for threshold in cli_args.threshold
    }


def _index_file(
    input_file: str,
    cli_args: argparse.Namespace,
    indexes: Dict[LSHParam, MinHashLSH],
    documents: Dict[str, Dict],
) -> int:
    """
    MinHash every record of one JSONL file in a process pool, then insert each
    record into all *indexes* and store it in *documents* (mutated in place).

    returns the number of records skipped because their id was already seen
    """
    skipped = 0
    worker = partial(
        process_records,
        shingle_length=cli_args.shingle_length,
        num_perm=cli_args.num_perm,
    )
    with smart_open.open(input_file, "rt", encoding="utf-8") as reader:
        with multiprocessing.Pool(
            processes=cli_args.num_processes,
            initializer=worker_init,
            initargs=(cli_args.sp_model,),
        ) as pool:
            for chunk in batch_iterator(tqdm(reader), batch_size=cli_args.chunk_size):
                for record in pool.imap(worker, chunk):
                    doc_id = record["id"]
                    if doc_id in documents:
                        # By default MinHashLSH.insert raises on a key that is
                        # already indexed, which would abort the whole run.
                        skipped += 1
                        continue
                    for index in indexes.values():
                        index.insert(doc_id, record["minhash"])
                    documents[doc_id] = record
    return skipped


def _report(
    params: LSHParam,
    index: MinHashLSH,
    documents: Dict[str, Dict],
    output_dir: Path,
) -> None:
    """Print dedup statistics for one index and write the surviving ids to a file."""
    total_tokens = 0
    filtered_tokens = 0
    deduped_docs: List[str] = []
    for id_, doc in documents.items():
        total_tokens += doc["tokens"]
        # The document's own id is expected among the query results, since its
        # MinHash was inserted. It survives only if it has the smallest id of
        # its near-duplicate group. This greedy rule is an approximation: it can
        # drop a document whose only near-duplicate was itself dropped.
        if id_ == min(index.query(doc["minhash"])):
            filtered_tokens += doc["tokens"]
            deduped_docs.append(id_)

    # Guard against empty input / token-less documents.
    doc_ratio = len(deduped_docs) / len(documents) if documents else 0.0
    token_ratio = filtered_tokens / total_tokens if total_tokens else 0.0

    print(f"Threshold: {params}:")
    print(
        f"Total number of documents {len(documents)}, total number of unique documents "
        + f"{len(deduped_docs)}, ratio {doc_ratio}"
    )
    print(
        f"Total number of tokens {total_tokens}, total number of unique tokens {filtered_tokens}, "
        + f"ratio {token_ratio}"
    )

    out_path = output_dir.joinpath(
        f"deduped_threshold-{params.threshold}.num_perm-{params.num_perm}."
        + f"shingle_length-{params.shingle_length}.tokens_left-{filtered_tokens}.txt"
    )
    with out_path.open("w", encoding="utf-8") as fh:
        fh.writelines(doc_id + "\n" for doc_id in deduped_docs)


def main(cli_args: argparse.Namespace) -> None:
    """
    Build MinHash LSH indexes (one per threshold) over every record matched by the
    input glob, print duplicate statistics, and write the ids of the documents
    kept by each index into ``cli_args.output_dir``.

    raises SystemExit if the input glob matches no files
    """
    cli_args.output_dir.mkdir(parents=True, exist_ok=True)
    indexes = _build_indexes(cli_args)
    documents: Dict[str, Dict] = {}

    # Sorted so the processing order (and thus output order) is deterministic.
    input_files = sorted(glob(cli_args.input_files))
    if not input_files:
        raise SystemExit(f"No input files match {cli_args.input_files!r}")

    skipped = 0
    for input_file in input_files:
        print(f"Processing {input_file}, sit tight...")
        skipped += _index_file(input_file, cli_args, indexes, documents)
    if skipped:
        print(f"Skipped {skipped} records whose id was already seen")

    print("Estimating number of unique documents...")
    for params, index in indexes.items():
        _report(params, index, documents, cli_args.output_dir)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Create LSH index of shingles from a JSONL file"
    )
    # A single positional: if the shell expanded the glob into several paths,
    # argparse would reject the extras, so the pattern must be quoted.
    parser.add_argument(
        "input_files",
        help="glob for input JSONL files (archives are welcome); quote it so the "
        "shell does not expand it",
    )
    parser.add_argument(
        "output_dir",
        help="path to output dir to store ids of deduplicated texts",
        type=Path,
    )
    parser.add_argument("sp_model", help="path to SentencePiece model file")
    parser.add_argument(
        "--shingle_length", type=int, default=3, help="length of shingles"
    )
    parser.add_argument(
        "--num_perm", type=int, default=128, help="number of permutations for MinHash"
    )
    # nargs="+" so a bare `--threshold` is rejected instead of silently
    # building zero indexes after all records have been processed.
    parser.add_argument(
        "--threshold",
        type=float,
        default=[0.5],
        nargs="+",
        help="one or more thresholds for LSH",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=1000,
        help="number of records to process in each chunk",
    )
    parser.add_argument(
        "--num_processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help="number of processes to use for parallel processing",
    )
    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment