Skip to content

Instantly share code, notes, and snippets.

@macleginn
Last active September 8, 2023 12:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save macleginn/9eb59ddd3be1c0d5510d05cc95a218be to your computer and use it in GitHub Desktop.
Save macleginn/9eb59ddd3be1c0d5510d05cc95a218be to your computer and use it in GitHub Desktop.
XSBERT worker process
import os
import sys
import pickle
import requests
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Pooling
from sentence_transformers import util
from xsbert import models
# XSBERT worker process.
#
# Repeatedly pulls a question pair from a local HTTP queue, computes the
# cosine similarity of the two questions with a SentenceTransformer encoder,
# computes an attribution matrix with the XSBERT explainer, and stores the
# result as `<id>.pickle` in OUTPUT_DIR. The worker stops when the queue
# answers with an empty JSON object.
#
# Usage: <script> <queue_port> <cuda_device_number>

# Directory that receives one pickle file per processed question pair.
OUTPUT_DIR = os.path.join('..', 'qqp_attributions')
MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'

if len(sys.argv) < 3:
    sys.exit(f'Usage: {sys.argv[0]} <queue_port> <cuda_device_number>')
queue_port = int(sys.argv[1])
cuda_device_number = int(sys.argv[2])
# Raise explicitly instead of using `assert`, so that the check is not
# stripped when the interpreter runs with -O.
if cuda_device_number not in range(9):
    raise ValueError(f'Wrong CUDA device id: {cuda_device_number}')
cuda_device = torch.device(f'cuda:{cuda_device_number}')

# Create the output directory up front; otherwise every write would fail
# only after the expensive attribution step and be swallowed by the
# best-effort handler in the loop below.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Pin the encoder to the requested GPU. A bare `.cuda()` would put it on the
# default CUDA device regardless of `cuda_device_number`, so workers started
# for different GPUs would all share one device for encoding.
encoder = SentenceTransformer(MODEL_NAME, device=str(cuda_device))

model = models.ReferenceTransformer(MODEL_NAME)
pooling = Pooling(model.get_word_embedding_dimension())
explainer = models.XSMPNet(modules=[model, pooling])
explainer.to(cuda_device)

while True:
    # NOTE(review): no timeout is set, so this call blocks for as long as
    # the queue server takes to answer.
    r = requests.get(f'http://localhost:{queue_port}')
    r.encoding = 'utf-8'
    data = r.json()
    # An empty payload is the queue's end-of-input signal.
    if not data:
        print('End of input. Shutting down.')
        break

    out_path = os.path.join(OUTPUT_DIR, f'{data["id"]}.pickle')
    # Skip items whose result already exists (e.g. from an earlier run).
    if os.path.exists(out_path):
        continue

    # Only the plain similarity score is computed without gradients; the
    # attribution step below stays outside of no_grad (presumably it relies
    # on autograd -- confirm against the xsbert implementation).
    with torch.no_grad():
        embeddings = encoder.encode(
            [data['question1'], data['question2']], convert_to_tensor=True)
        cos_sim = util.pairwise_cos_sim(
            embeddings[0:1], embeddings[1:]).item()
    print('Sentence 1:', data['question1'])
    print('Sentence 2:', data['question2'])
    print(f'Similarity: {cos_sim}')

    explainer.reset_attribution()
    explainer.init_attribution_to_layer(idx=8, N_steps=100)
    try:
        A, ta, tb, *_ = explainer.explain_similarity(
            data['question1'],
            data['question2'],
            return_score=True,
            sim_measure='cos'
        )
        # Write to a temporary file and rename it into place, so that an
        # interrupted write never leaves a truncated `<id>.pickle` that the
        # existence check above would treat as finished on the next run.
        tmp_path = f'{out_path}.{os.getpid()}.tmp'
        with open(tmp_path, 'wb') as out:
            pickle.dump({
                'id': data['id'],
                'tokens_a': ta,
                'tokens_b': tb,
                'similarity': cos_sim,
                'A': A
            }, out)
        os.replace(tmp_path, out_path)
    # Something bad happened; most probably CUDA OOM error. Deliberately
    # best-effort: report the failing item and move on to the next one.
    except Exception as e:
        print(f'Failed to process item {data["id"]}: {e}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment