@19h
Created November 3, 2023 23:18
This Python script extracts embeddings from a CSV of news articles (all-the-news-2-1.csv) using the pretrained jinaai/jina-embeddings-v2-base-en model, a BERT-based embedding model. For each row it concatenates the title and article body, embeds the texts in batches, serializes each embedding as JSON, and writes the embedding together with the row's metadata to a new CSV file.
import csv
import ctypes as ct
import json

import torch
from tqdm import tqdm
from transformers import AutoModel, BertTokenizerFast

# Articles in this dataset can exceed csv's default 128 KiB field limit,
# so raise the limit to the platform maximum.
csv.field_size_limit(int(ct.c_ulong(-1).value // 2))

model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
# jina-embeddings-v2-base-en uses a BERT vocabulary, so the stock
# bert-base-uncased tokenizer is compatible; note it truncates inputs
# at its 512-token model_max_length.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
if device.type == 'cuda':
    model = model.half()  # fp16 halves memory use and speeds up GPU inference

FIELDS = ['date', 'year', 'month', 'day', 'author', 'section', 'publication', 'title']

def embed(texts):
    """Tokenize a batch of texts and return their [CLS] embeddings as a numpy array."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
    with torch.no_grad():
        # Pass the attention mask explicitly so padding tokens are ignored.
        outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    # Use the [CLS] token's hidden state as the sentence embedding.
    return outputs.last_hidden_state[:, 0, :].float().cpu().numpy()

with open('all-the-news-2-1.csv', 'r') as f, open('atn-embd.csv', 'w') as fout:
    reader = csv.DictReader(f)
    writer = csv.DictWriter(fout, fieldnames=FIELDS + ['embedding'])
    writer.writeheader()

    def write_batch(texts, rows):
        embeddings = embed(texts)
        for row, emb in zip(rows, embeddings):
            row['embedding'] = json.dumps(emb.tolist())
            writer.writerow(row)

    batch_size = 64
    texts, rows = [], []
    for row in tqdm(reader):
        # Embed the title and the article body together.
        texts.append(row['title'] + '\n' + row['article'])
        rows.append({k: row[k] for k in FIELDS})
        if len(texts) == batch_size:
            write_batch(texts, rows)
            texts, rows = [], []

    # Flush the final, partially filled batch.
    if texts:
        write_batch(texts, rows)
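For downstream use, each embedding can be parsed back out of the JSON column. A minimal sketch of reading the output file and running a cosine-similarity lookup; the load_embeddings helper is illustrative and not part of the gist, assuming only the atn-embd.csv layout written above:

import csv
import json

import numpy as np

# Illustrative helper: load titles and embedding vectors back from the
# CSV produced by the script above.
def load_embeddings(path='atn-embd.csv'):
    titles, vectors = [], []
    with open(path, 'r') as f:
        for row in csv.DictReader(f):
            titles.append(row['title'])
            vectors.append(json.loads(row['embedding']))
    return titles, np.asarray(vectors, dtype=np.float32)

titles, vectors = load_embeddings()
# Normalize rows, then cosine similarity reduces to a dot product.
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
unit = vectors / np.clip(norms, 1e-12, None)
scores = unit @ unit[0]                      # similarity of every article to article 0
print(titles[int(np.argsort(-scores)[1])])   # most similar other article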