# Gist by @seanbenhur, created October 26, 2023 18:11
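"""Score parallel sentence pairs with BERT-based cosine similarity.

Reads a file of 'source ||| target' pairs, encodes both sides with a BERT
masked-LM, and writes one cosine-similarity score per pair to the output
file. Reading is sharded across DataLoader workers by byte offset.
"""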
import argparse
import functools
import itertools
import os
import random
import shutil
import tempfile

import numpy as np
import torch
from torch.nn.functional import cosine_similarity
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterableDataset
from tqdm import trange
from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)


def set_seed(args):
    """Seed Python, NumPy and PyTorch RNGs when args.seed is non-negative."""
    if args.seed >= 0:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)


class LineByLineTextDataset(IterableDataset):
    """Streams 'source ||| target' pairs from a text file, optionally sharded
    across DataLoader workers via precomputed byte offsets."""

    def __init__(self, tokenizer: PreTrainedTokenizer, file_path, offsets=None):
        assert os.path.isfile(file_path)
        print('Loading the dataset...')
        self.tokenizer = tokenizer
        self.file_path = file_path
        self.offsets = offsets
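
    # Each input line is expected to hold a whitespace-tokenized sentence pair
    # joined by ' ||| ', e.g. (illustrative): das ist ein Test ||| this is a test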
    def process_line(self, worker_id, line):
        if len(line) == 0 or line.isspace() or len(line.split(' ||| ')) != 2:
            return None
        src, tgt = line.split(' ||| ')
        if src.rstrip() == '' or tgt.rstrip() == '':
            return None
        sent_src, sent_tgt = src.strip().split(), tgt.strip().split()
        # Tokenize word by word so subword pieces can be mapped back to word indices.
        token_src = [self.tokenizer.tokenize(word) for word in sent_src]
        token_tgt = [self.tokenizer.tokenize(word) for word in sent_tgt]
        wid_src = [self.tokenizer.convert_tokens_to_ids(x) for x in token_src]
        wid_tgt = [self.tokenizer.convert_tokens_to_ids(x) for x in token_tgt]
        ids_src = self.tokenizer.prepare_for_model(
            list(itertools.chain(*wid_src)), return_tensors='pt',
            max_length=512, truncation=True)['input_ids']
        ids_tgt = self.tokenizer.prepare_for_model(
            list(itertools.chain(*wid_tgt)), return_tensors='pt',
            max_length=512, truncation=True)['input_ids']
        # Depending on the transformers version, prepare_for_model may prepend a
        # batch axis; normalize both to 1-D tensors of token ids.
        if ids_src.dim() > 1:
            ids_src = ids_src[0]
        if ids_tgt.dim() > 1:
            ids_tgt = ids_tgt[0]
        # Length 2 means only [CLS] and [SEP] remain, i.e. the sentence is empty.
        if ids_src.shape[0] == 2 and ids_tgt.shape[0] == 2:
            return None
        # Map each subword position back to the index of the word it came from.
        bpe2word_map_src = []
        for i, word_list in enumerate(token_src):
            bpe2word_map_src += [i] * len(word_list)
        bpe2word_map_tgt = []
        for i, word_list in enumerate(token_tgt):
            bpe2word_map_tgt += [i] * len(word_list)
        return (worker_id, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sent_src, sent_tgt)
def __iter__(self):
if self.offsets is not None:
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id
offset_start = self.offsets[worker_id]
offset_end = self.offsets[worker_id+1] if worker_id+1 < len(self.offsets) else None
else:
offset_start = 0
offset_end = None
worker_id = 0
with open(self.file_path, encoding="utf-8") as f:
f.seek(offset_start)
line = f.readline()
while line:
processed = self.process_line(worker_id, line)
                if processed is None:
                    print(f'Line "{line.strip()}" (offset in bytes: {f.tell()}) is not in the correct format. Skipping...')
                    # Yield a dummy example ([CLS], an arbitrary id, [SEP]) so the
                    # stream stays aligned across workers.
                    empty_tensor = torch.tensor([self.tokenizer.cls_token_id, 999, self.tokenizer.sep_token_id])
                    empty_sent = ''
                    yield (worker_id, empty_tensor, empty_tensor, [-1], [-1], empty_sent, empty_sent)
else:
yield processed
if offset_end is not None and f.tell() >= offset_end:
break
line = f.readline()
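

# find_offsets splits the input file into num_workers byte ranges aligned to
# line boundaries, so each DataLoader worker reads a disjoint slice of the corpus.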
def find_offsets(filename, num_workers):
if num_workers <= 1:
return None
with open(filename, "r", encoding="utf-8") as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offsets = [0]
for i in range(1, num_workers):
f.seek(chunk_size * i)
pos = f.tell()
        while True:
            try:
                # Advance to the next line boundary; if the seek landed inside a
                # multibyte character, back up one byte and retry.
                f.readline()
                break
            except UnicodeDecodeError:
                pos -= 1
                f.seek(pos)
        offsets.append(f.tell())
return offsets
def open_writer_list(filename, num_workers):
    """Open the real output file plus one temporary file per extra worker."""
    writer = open(filename, 'w+', encoding='utf-8')
    writers = [writer]
    if num_workers > 1:
        writers.extend(
            tempfile.TemporaryFile(mode='w+', encoding='utf-8')
            for _ in range(1, num_workers)
        )
    return writers
def merge_files(writers):
    """Append each worker's temporary file onto the main writer, then close all."""
    if len(writers) == 1:
        writers[0].close()
        return
    for writer in writers[1:]:
        writer.seek(0)
        shutil.copyfileobj(writer, writers[0])
        writer.close()
    writers[0].close()
def collate(examples, tokenizer):
    """Pad the variable-length id tensors in a batch with the pad token.

    The tokenizer is a required argument: the original default instantiated
    bert-base-uncased at import time and ignored the tokenizer chosen on the
    command line.
    """
    worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = zip(*examples)
    ids_src = pad_sequence(ids_src, batch_first=True, padding_value=tokenizer.pad_token_id)
    ids_tgt = pad_sequence(ids_tgt, batch_first=True, padding_value=tokenizer.pad_token_id)
    return worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt


def compute_embeddings(model, tokenizer, input_ids, device):
    """Encode a padded batch and return one pooled vector per sentence.

    The gist treats the masked-LM output logits as embeddings; they are
    mean-pooled over the sequence dimension here so each sentence maps to a
    single fixed-size vector (the original flattened the batch dimension,
    which broke the per-sentence iteration in word_align). The tokenizer
    argument is kept for signature compatibility but is unused.
    """
    model.eval()
    with torch.no_grad():
        input_ids = input_ids.to(device)
        outputs = model(input_ids)
        logits = outputs.logits  # batch_size x seq_len x vocab_size
        # Mean-pool over seq_len (pad positions included for simplicity).
        sentence_embeddings = logits.mean(dim=1)  # batch_size x vocab_size
    return sentence_embeddings
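
# Using vocabulary logits as sentence representations is unusual; a common
# alternative (a sketch, not part of the original gist) is to mean-pool the
# final hidden states instead:
#
#   outputs = model(input_ids, output_hidden_states=True)
#   sent_emb = outputs.hidden_states[-1].mean(dim=1)  # batch_size x hidden_size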


def word_align(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    offsets = find_offsets(args.data_file, args.num_workers)
    dataset = LineByLineTextDataset(tokenizer, file_path=args.data_file, offsets=offsets)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size,
        # functools.partial binds the CLI-selected tokenizer; the module-level
        # default previously hard-coded bert-base-uncased regardless of args.
        collate_fn=functools.partial(collate, tokenizer=tokenizer),
        num_workers=args.num_workers,
    )
model.to(args.device)
model.eval()
tqdm_iterator = trange(0, desc="Extracting")
cos_sim_writers = open_writer_list(args.output_file, args.num_workers)
# if args.output_prob_file is not None:
# prob_writers = open_writer_list(args.output_prob_file, args.num_workers)
# if args.output_word_file is not None:
# word_writers = open_writer_list(args.output_word_file, args.num_workers)
for batch in dataloader:
with torch.no_grad():
worker_ids, ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt, sents_src, sents_tgt = batch
embeddings_src = compute_embeddings(model, tokenizer, ids_src.to(args.device), args.device)
embeddings_tgt = compute_embeddings(model, tokenizer, ids_tgt.to(args.device), args.device)
            for worker_id, sent_src, sent_tgt, emb_src, emb_tgt in zip(worker_ids, sents_src, sents_tgt, embeddings_src, embeddings_tgt):
                print("SENT_SRC: ", sent_src)
                # Cosine similarity between the pooled source and target vectors.
                cos_sim = cosine_similarity(emb_src.unsqueeze(0), emb_tgt.unsqueeze(0))
                # Join the word lists so the file holds readable text, not list reprs.
                cos_sim_writers[worker_id].write(f"{' '.join(sent_src)}-{' '.join(sent_tgt)}: {cos_sim.item()}\n")
tqdm_iterator.update(len(ids_src))
merge_files(cos_sim_writers)
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--data_file", default="outs.txt", type=str, help="The input data file (a text file)."
)
parser.add_argument(
"--output_file",
default="cos_sim.txt",
type=str,
help="The output file."
)
parser.add_argument("--align_layer", type=int, default=8, help="layer for alignment extraction")
parser.add_argument(
"--extraction", default='softmax', type=str, help='softmax or entmax15'
)
parser.add_argument(
"--softmax_threshold", type=float, default=0.001
)
parser.add_argument(
"--output_prob_file", default=None, type=str, help='The output probability file.'
)
parser.add_argument(
"--output_word_file", default=None, type=str, help='The output word file.'
)
parser.add_argument(
"--model_name_or_path",
default="bert-base-uncased",
type=str,
help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
)
parser.add_argument(
"--config_name",
default="bert-base-uncased",
type=str,
help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
)
parser.add_argument(
"--tokenizer_name",
default="bert-base-uncased",
type=str,
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument(
"--cache_dir",
default=None,
type=str,
help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
)
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading")
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.device = device
# Set seed
set_seed(args)
config_class, model_class, tokenizer_class = BertConfig, BertForMaskedLM, BertTokenizer
if args.config_name:
config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
elif args.model_name_or_path:
config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
else:
config = config_class()
if args.tokenizer_name:
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
elif args.model_name_or_path:
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
else:
        raise ValueError(
            "You are instantiating a new {} tokenizer. This is not supported, but you can do it from "
            "another script, save it, and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
        )
# modeling.PAD_ID = tokenizer.pad_token_id
# modeling.CLS_ID = tokenizer.cls_token_id
# modeling.SEP_ID = tokenizer.sep_token_id
if args.model_name_or_path:
model = model_class.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir,
)
else:
model = model_class(config=config)
word_align(args, model, tokenizer)
if __name__ == "__main__":
main()
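
# Example usage (illustrative; the script and file names are placeholders):
#
#   Each line of the input file pairs a source and target sentence with ' ||| ':
#       das ist ein Test ||| this is a test
#
#   python score_pairs.py --data_file parallel.txt --output_file cos_sim.txt \
#       --model_name_or_path bert-base-uncased --batch_size 32 --num_workers 4
#
# Note that bert-base-uncased is an English model; for cross-lingual pairs a
# multilingual checkpoint (e.g. bert-base-multilingual-cased) passed via
# --model_name_or_path would be the more natural choice.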