Multithread wiki data encoder
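Spawns multiple worker processes that pull extracted Wikipedia pages from a shared queue, encode 512-character chunks of each page with a DPR context encoder, and append the resulting embeddings to per-worker JSON-lines files.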
#!/home/ubuntu/anaconda3/envs/clip/bin/python
# The raw XML dump is not used below (the script reads the extracted JSON
# lines in text/AA), and parsing it at module level would re-run in every
# spawned worker, so it is left commented out.
# import untangle
# obj = untangle.parse('dump.xml')
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from ctypes import c_int
from queue import Empty
import torch
import json
from time import sleep, time
import os
import torch.multiprocessing as mp
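
# Pipeline overview:
#   1. collect_data() reads the extracted wiki pages from text/AA and filters
#      out non-article pages by title keyword.
#   2. multithread() puts every page on a shared queue and starts 20 Encoder
#      worker processes.
#   3. Each worker encodes 512-character chunks of each page with the DPR
#      context encoder and appends the results to its own documents/document<N>.json.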

def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


class Encoder(mp.Process):
    """
    A worker process that generates DPR context embeddings and writes them to a file
    """
    def __init__(self, process_number, queue, completed_event, completed_count):
        """
        queue should be the common queue shared among processes;
        process_number makes this worker's datafile name unique
        """
        super(Encoder, self).__init__()
        self.process_number = process_number
        self.queue = queue
        self.datafile = "documents/document%d.json" % self.process_number
        self.completed_event = completed_event
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.completed_count = completed_count
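
    # load_model() is called from run(), i.e. inside the spawned worker process,
    # so each worker builds its own tokenizer, model, and (if available) CUDA
    # context rather than trying to share them across processes.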
    def load_model(self):
        self.tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        self.model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(self.device)
    def generate_embeddings(self, data):
        """
        Generate DPR embeddings for 512-character chunks of a page's text
        """
        text_block = data['text']
        embeddings = []
        for chunk in chunks(text_block, 512):
            try:
                with torch.no_grad():
                    # Truncate to the encoder's maximum length; a 512-character
                    # chunk plus the title can occasionally exceed the 512-token limit.
                    input_ids = self.tokenizer(data["title"] + " [SEP] " + chunk,
                                               return_tensors="pt",
                                               truncation=True,
                                               max_length=512).to(self.device)["input_ids"]
                    embeddings.append({"text": chunk, "embedding": self.model(input_ids).pooler_output.tolist()})
            except Exception as e:
                print("Failed to create embedding ", e)
        # print("produced ", len(embeddings), " embeddings")
        return embeddings
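
    # Each embedded chunk becomes one JSON line on disk:
    #   {"title": ..., "text": <chunk>, "embedding": [[...]]}
    # where "embedding" is the pooler_output as a nested list
    # (768 floats for this base-size DPR model).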
    def write_to_datafile(self, title, chunks):
        """
        Append one JSON line per chunk to this worker's datafile
        """
        with open(self.datafile, 'a') as file:
            for chunk in chunks:
                obj = {
                    'title': title,
                    'text': chunk['text'],
                    'embedding': chunk['embedding']
                }
                json.dump(obj, file)
                file.write('\n')
    def process(self, obj):
        # print("processing ", obj['title'])
        try:
            chunks = self.generate_embeddings(obj)
            self.write_to_datafile(obj['title'], chunks)
        except Exception as e:
            print("Failed to process ", e)
    def run(self):
        """
        Pull from the queue until the queue is empty and the main process
        signals completion. Results are written straight to a datafile
        instead of kept in memory because we may be processing a lot of pages.
        """
        self.load_model()
        while not (self.queue.empty() and self.completed_event.is_set()):
            try:
                # Use a timeout so the worker re-checks the exit condition
                # instead of blocking forever on an empty queue.
                data = self.queue.get(timeout=1)
            except Empty:
                continue
            self.process(data)
            # Increment the shared counter under its lock; an unlocked
            # get_obj().value += 1 is not atomic across processes.
            with self.completed_count.get_lock():
                self.completed_count.value += 1


def collect_data():
    """Read the extracted wiki pages in text/AA, skipping non-article pages."""
    skip_keywords = ("talk", "file", "user", "template", "development resources",
                     "category", "wiki", "disambiguation", "removed")
    datafiles = os.listdir('text/AA')
    data = []
    for f in datafiles:
        with open(os.path.join('text/AA', f)) as file:
            for line in file:
                page = json.loads(line)
                title = page["title"].lower()
                # Skip talk/user/template/etc. pages whose titles contain a keyword
                if any(keyword in title for keyword in skip_keywords):
                    continue
                data.append(page)
    print("read ", len(data), " items")
    return data


def multithread():
    print("preparing initial objects...")
    queue = mp.Queue()
    completed_event = mp.Event()
    completed_count = mp.Value(c_int)
    print("reading datafiles")
    for data in collect_data():
        queue.put(data)
    # start the workers reading from the queue
    print("starting worker processes")
    worker_instances = 20
    encoders = [Encoder(i, queue, completed_event, completed_count) for i in range(worker_instances)]
    for c in encoders:
        c.start()
    start_time = time()
    print("starting monitor loop")
    while not (queue.empty() and completed_event.is_set()):
        diff_time = time() - start_time
        count = completed_count.value
        print("Processed %s %.2f/s" % (count, count / diff_time))
        sleep(4)
    count = completed_count.value
    print("Completed %s" % count)
    completed_event.set()
    for c in encoders:
        c.join()


def singlethread():
    """
    A simple single-process test for debugging
    """
    queue = mp.Queue()
    completed_event = mp.Event()
    completed_count = mp.Value(c_int)
    c = Encoder(0, queue, completed_event, completed_count)
    for data in collect_data():
        queue.put(data)
    # The queue is fully populated before run() starts, so signal completion
    # up front; run() then exits once the queue drains.
    completed_event.set()
    c.run()


if __name__ == '__main__':
    print("starting...")
    # 'spawn' is required so each worker gets a fresh interpreter and its own
    # CUDA context; 'fork' does not play well with CUDA.
    mp.set_start_method('spawn')
    # singlethread()
    multithread()