Last active: December 7, 2022, 22:53
-
-
Save mallamanis/ce1a3624b6d1a9ec9b6966e6b7181dcd to your computer and use it in GitHub Desktop.
Compute the duplicate clusters for https://huggingface.co/datasets/lvwerra/github-code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import multiprocessing as mp | |
import re | |
from collections import defaultdict | |
from typing import List, Optional, Set | |
from datasets import load_dataset | |
from datasketch import MinHash, MinHashLSH, minhash | |
from dpu_utils.utils.iterators import ThreadedIterator | |
from tqdm import tqdm | |
class DuplicationIndex:
    """Incrementally clusters near-duplicate files using MinHash LSH.

    Files whose estimated Jaccard similarity exceeds
    ``duplication_jaccard_threshold`` are grouped into the same cluster,
    keyed by the first indexed member ("base") of that cluster.
    """

    def __init__(
        self,
        *,
        duplication_jaccard_threshold: float = 0.85,
        num_perm: int = 256,
        min_num_tokens: int = 10,
    ):
        """
        Args:
            duplication_jaccard_threshold: minimum estimated Jaccard
                similarity for two files to count as duplicates.
            num_perm: number of MinHash permutations (accuracy/speed knob).
            min_num_tokens: files with fewer tokens are skipped entirely.
        """
        self.__duplication_jaccard_threshold = duplication_jaccard_threshold
        self.__num_perm = num_perm
        self.__min_num_tokens = min_num_tokens
        self.__index = MinHashLSH(
            threshold=self.__duplication_jaccard_threshold, num_perm=self.__num_perm
        )
        # cluster base filename -> set of filenames that duplicate it
        self.__duplicate_clusters = defaultdict(set)

    def get_min_hash(self, tokens: List[str]) -> Optional[MinHash]:
        """Return the MinHash of *tokens*, or None if the list is too short.

        Tokens are deduplicated first, so the fingerprint reflects the token
        *set* (as Jaccard similarity requires), not token frequencies.
        """
        if len(tokens) < self.__min_num_tokens:
            return None
        min_hash = MinHash(num_perm=self.__num_perm)
        for token in set(tokens):
            min_hash.update(token.encode())
        return min_hash

    def add(self, filename: str, min_hash: MinHash) -> None:
        """Insert *filename* into the index, attaching it to a duplicate cluster.

        If close duplicates already exist, the file joins the first of them
        that already owns a cluster; otherwise a new cluster is started with
        the first close duplicate as its base. Re-adding a known filename is
        a no-op (with a diagnostic print).
        """
        # Query BEFORE inserting so the file is not reported as its own duplicate.
        close_duplicates = self.__index.query(min_hash)
        if filename in self.__index.keys:
            print(f"Duplicate key {filename}")
            return

        self.__index.insert(filename, min_hash)
        if close_duplicates:
            for base_duplicate in close_duplicates:
                if base_duplicate in self.__duplicate_clusters:
                    self.__duplicate_clusters[base_duplicate].add(filename)
                    break
            else:
                # for/else: no close duplicate owns a cluster yet — start one,
                # keyed by the first close duplicate returned.
                self.__duplicate_clusters[close_duplicates[0]].add(filename)

    def save(self, filepath: str) -> None:
        """Write clusters to *filepath* as JSON: a list of filename lists.

        Each inner list contains a cluster's members plus its base filename.
        """
        duplicate_clusters = [
            list(duplicates) + [base]
            for base, duplicates in self.__duplicate_clusters.items()
        ]
        with open(filepath, "w") as f:
            json.dump(duplicate_clusters, f)
if __name__ == "__main__":
    import sys

    # Comma-separated language filter for the dataset, e.g. "Python,Java".
    languages = sys.argv[1].split(",")
    di = DuplicationIndex()

    # A very approximate tokenization for most programming languages:
    # split on anything that is not an identifier character.
    NON_ALPHA = re.compile("[^A-Za-z_0-9]")

    def compute_min_hash(element):
        # Fingerprint one dataset record. Returns a ("repo::path", MinHash)
        # pair, or None (implicitly) when the file has too few tokens.
        # Runs inside worker processes; relies on module-level `di` and
        # `NON_ALPHA` being inherited by the pool (fork start method —
        # TODO confirm this is not run on a spawn-based platform).
        min_hash = di.get_min_hash(
            [t for t in NON_ALPHA.split(element["code"]) if len(t.strip()) > 0]
        )
        if min_hash is not None:
            return element["repo_name"] + "::" + element["path"], min_hash

    # Stream the dataset so it is never fully materialized locally.
    ds = iter(
        load_dataset(
            "lvwerra/github-code", streaming=True, split="train", languages=languages
        )
    )

    def minhash_iter():
        # Fan MinHash computation out over a process pool; records whose
        # compute_min_hash returned None (too-short files) are dropped here.
        # ThreadedIterator presumably prefetches on a background thread to
        # keep the pool fed — verify against dpu_utils docs.
        with mp.Pool() as pool:
            for data in pool.imap_unordered(
                compute_min_hash,
                ThreadedIterator(ds, max_queue_size=10000),
                chunksize=100,
            ):
                if data is not None:
                    yield data

    # Index insertion itself is sequential in the main process.
    for filename, min_hash in tqdm(
        ThreadedIterator(minhash_iter(), max_queue_size=100)
    ):
        di.add(filename, min_hash)

    # Returns a List[Cluster] where Cluster is List[str] with the filenames.
    di.save(f"duplicates-{'-'.join(languages)}.json")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.