Skip to content

Instantly share code, notes, and snippets.

@128f
Last active July 22, 2023 16:51
Show Gist options
  • Save 128f/da53ccbb80797cb8d3d20d25b0d4bac1 to your computer and use it in GitHub Desktop.
Multithread wiki data encoder
#!/home/ubuntu/anaconda3/envs/clip/bin/python
import untangle
# NOTE(review): this parse runs at import time and `obj` is never read again
# in this file — every later `obj` is a local parameter that shadows it.
# Presumably leftover from an earlier revision; confirm before removing, since
# parsing dump.xml here costs significant time and memory on every import.
obj = untangle.parse('dump.xml')
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from ctypes import c_int
import torch
import json
from time import sleep, time
import os
import torch.multiprocessing as mp
def chunks(lst, n):
    """Yield successive n-sized chunks from lst (last chunk may be shorter)."""
    start = 0
    while start < len(lst):
        yield lst[start:start + n]
        start += n
class Encoder(mp.Process):
    """
    A worker process that pulls wiki pages from a shared queue, generates
    DPR context embeddings for them, and appends the results to a
    per-process JSON-lines file.
    """

    def __init__(self, process_number, queue, completed_event, completed_count):
        """
        queue should be the common queue shared among processes
        completed_event signals that no more work will be enqueued
        completed_count is a shared mp.Value used for progress reporting
        """
        super(Encoder, self).__init__()
        self.process_number = process_number
        self.queue = queue
        # one output file per worker so processes never contend on a writer
        self.datafile = "documents/document%d.json" % self.process_number
        self.completed_event = completed_event
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.completed_count = completed_count

    def load_model(self):
        """Load tokenizer and model inside the child process (spawn-safe)."""
        self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base")
        self.model = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base").to(self.device)

    def generate_embeddings(self, data):
        """
        Split data['text'] into 512-character chunks and return a list of
        {"text", "embedding"} dicts, one per chunk. A chunk that fails to
        encode is skipped with a warning rather than aborting the page.
        """
        text_block = data['text']
        embeddings = []
        for chunk in chunks(text_block, 512):
            try:
                with torch.no_grad():
                    input_ids = self.tokenizer(
                        data["title"] + " [SEP] " + chunk,
                        return_tensors="pt").to(self.device)["input_ids"]
                    embeddings.append({
                        "text": chunk,
                        "embedding": self.model(input_ids).pooler_output.tolist(),
                    })
            except Exception as e:
                print("Failed to create embedding ", e)
        return embeddings

    def write_to_datafile(self, title, chunks):
        """
        Append one JSON object per chunk to this worker's datafile.
        The file is opened once per call (the original reopened it for
        every chunk).
        """
        with open(self.datafile, 'a') as file:
            for chunk in chunks:
                obj = {
                    'title': title,
                    'text': chunk['text'],
                    'embedding': chunk['embedding'],
                }
                json.dump(obj, file)
                file.write('\n')

    def process(self, obj):
        """Embed one page dict and persist the result; never raises."""
        try:
            chunks = self.generate_embeddings(obj)
            self.write_to_datafile(obj['title'], chunks)
        except Exception as e:
            print("Failed to process ", e)

    def run(self):
        """
        Pull from the queue until it is drained and the producer has
        signalled completion. Continuously write to a datafile instead of
        storing in memory because we may be processing a lot of data.

        Uses a timeout on get() so the worker re-checks the exit condition:
        the original unconditional get() could block forever if the queue
        emptied between the empty() check and the get().
        """
        from queue import Empty  # mp.Queue raises queue.Empty on timeout
        self.load_model()
        while not (self.queue.empty() and self.completed_event.is_set()):
            try:
                data = self.queue.get(timeout=1)
            except Empty:
                continue
            self.process(data)
            # get_lock() makes the read-modify-write atomic across processes;
            # a bare `get_obj().value += 1` can silently drop increments.
            with self.completed_count.get_lock():
                self.completed_count.value += 1
def collect_data():
    """
    Read every JSON-lines file under text/AA and return the pages whose
    lowercased title contains none of the blocked keywords (meta /
    maintenance pages rather than articles).

    Returns a list of page dicts as parsed from the dump files.
    """
    # keywords that mark a page as non-article content; matched as substrings
    # of the lowercased title, exactly like the original if-chain
    blocked = ("talk", "file", "user", "template", "development resources",
               "category", "wiki", "disambiguation", "removed")
    data = []
    for f in os.listdir('text/AA'):
        with open(os.path.join('text/AA', f)) as file:
            for line in file:
                page = json.loads(line)
                title = page["title"].lower()
                if any(keyword in title for keyword in blocked):
                    continue
                data.append(page)
    print("read ", len(data), " items")
    return data
def multithread():
    """
    Fan the collected pages out to a pool of Encoder worker processes and
    monitor throughput until the queue is drained, then signal the workers
    to exit and join them.
    """
    print("preparing initial objects...")
    queue = mp.Queue()
    completed_event = mp.Event()
    completed_count = mp.Value(c_int)
    print("reading datafiles")
    # plain loop instead of a side-effect list comprehension
    for data in collect_data():
        queue.put(data)
    # start reading them queues
    print("starting worker threads")
    worker_instances = 20
    encoders = [Encoder(i, queue, completed_event, completed_count)
                for i in range(worker_instances)]
    for c in encoders:
        c.start()
    start_time = time()
    print("starting monitor loop")
    # BUG FIX: the original condition also required completed_event.is_set(),
    # but the event is only set *after* this loop — so the loop could never
    # terminate. The monitor only needs to wait for the queue to drain.
    while not queue.empty():
        diff_time = time() - start_time
        count = completed_count.value
        print("Processed %s %.2f/s" % (count, count / diff_time))
        sleep(4)
    count = completed_count.value
    print("Completed %s" % count)
    # tell workers no more work is coming so they can exit their run() loop
    completed_event.set()
    for c in encoders:
        c.join()
def singlethread():
    """
    A simple test for debugging: run a single Encoder inline in this
    process (via run(), not start()) against the full dataset.
    """
    work_queue = mp.Queue()
    done_event = mp.Event()
    done_count = mp.Value(c_int)
    encoder = Encoder(0, work_queue, done_event, done_count)
    for page in collect_data():
        work_queue.put(page)
    encoder.run()
if __name__ == '__main__':
    print("starting...")
    # 'spawn' start method: children re-import the module and load their own
    # model/CUDA state rather than inheriting forked parent state
    mp.set_start_method('spawn')
    # singlethread()
    multithread()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment