Created
April 20, 2023 12:12
-
-
Save napsternxg/18deca3a9a4dded4c20726308a269a2b to your computer and use it in GitHub Desktop.
Gen Text Embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
import torch | |
from transformers import CLIPProcessor, CLIPTextModelWithProjection | |
from accelerate import Accelerator | |
from datasets import Dataset | |
import pandas as pd | |
import numpy as np | |
from tqdm.auto import tqdm | |
from utils import timeit | |
# Module-level Accelerator shared by run_inference() and main(); one per launched process.
accelerator = Accelerator()
@timeit
def run_inference(model_type, data_path, batch_size=32, use_accelerate=False, testing=True):
    """Embed every row of a TSV file with a CLIP text encoder.

    Args:
        model_type: HF model id or local checkpoint path for a CLIP model.
        data_path: TSV file; first column is the index, second is the text.
        batch_size: per-process DataLoader batch size.
        use_accelerate: if True, gather outputs across processes with
            ``gather_for_metrics`` so the result matches the dataset size
            without distributed padding duplicates.
        testing: if True, only embed the first 1000 rows.

    Returns:
        np.ndarray of shape (num_rows, projection_dim) with the text embeddings.
    """
    text_model = CLIPTextModelWithProjection.from_pretrained(model_type)
    processor = CLIPProcessor.from_pretrained(model_type)
    print(f"Loaded {model_type=}")
    # fillna("") so empty cells tokenize as empty strings instead of NaN.
    df = pd.read_csv(data_path, sep="\t", index_col=0, header=None, names=["text"]).fillna("")
    if testing:
        df = df.head(1000)
    print(f"Loaded {df.shape=}")
    max_length = text_model.config.max_position_embeddings

    def encode(examples):
        # Pad to a fixed length so every batch stacks into uniform tensors.
        return processor.tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)

    # Tokenize on the local main process first so the datasets cache is built once
    # and reused by the other ranks instead of recomputed per process.
    with accelerator.local_main_process_first():
        dataset = Dataset.from_pandas(df).map(encode, batched=True)
    print(f"Converted to {dataset=}")
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    text_model.eval()
    text_model, dataloader = accelerator.prepare(text_model, dataloader)
    embeddings = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="run inference", disable=not accelerator.is_local_main_process):
            output = text_model(**batch)
            # For eval always call accelerator.gather_for_metrics so that output is same as datasize without padding
            # https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.gather_for_metrics
            if use_accelerate:
                output = accelerator.gather_for_metrics(output)
            # Fix: dropped the unused `all_outs` accumulator and the commented-out
            # debug appends, and removed the per-batch shape print that spammed
            # stdout from every rank on every step.
            embeddings.append(output.text_embeds.detach().cpu().numpy())
    embeddings = np.vstack(embeddings)
    print(f"{embeddings.shape=}")
    return embeddings
@timeit
def main(args):
    """Run distributed CLIP text-embedding inference and save the result.

    The output .npy file is written next to the input TSV, with the model
    name spliced into the filename. Only the local main process writes.
    """
    model_type = Path(args.model_type).expanduser()
    data_path = args.data_path
    embeddings = run_inference(
        model_type,
        data_path,
        batch_size=args.batch_size,
        use_accelerate=True,
        testing=False,
    )
    if not accelerator.is_local_main_process:
        return
    print(f"Output embeddings: {embeddings.shape}")
    output_path = data_path.replace(".tsv", f".{model_type.name}.npy")
    print(f"Saving to: {output_path}")
    np.save(output_path, embeddings)
    print(f"Saved to: {output_path}")
def get_parser():
    """Build the CLI parser: model id, input TSV path, and batch size."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", help="Model type for SentenceTransformer")
    parser.add_argument(
        "--data_path",
        help="data_path for data frame with first col as index and second col as text",
    )
    # Default chosen for large multi-GPU runs (512 * 8 = 4096).
    parser.add_argument("--batch_size", type=int, default=512 * 8)
    return parser
if __name__ == "__main__":
    # Script entry point: parse CLI args, echo them, and run inference.
    # NOTE(review): presumably launched via `accelerate launch` for multi-GPU — confirm.
    parser = get_parser()
    args = parser.parse_args()
    print(args)
    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from transformers import AutoModel, AutoTokenizer | |
from sentence_transformers import SentenceTransformer | |
import pandas as pd | |
import numpy as np | |
from pathlib import Path | |
from utils import timeit | |
@timeit
def main(args):
    """Embed a TSV of texts with SentenceTransformer on all GPUs, save as .npy.

    Spawns one encoder worker per visible CUDA device, encodes the whole
    text column, and writes the stacked embeddings next to the input TSV.
    """
    model_type = Path(args.model_type).expanduser()
    model = SentenceTransformer(model_type)
    data_path = args.data_path
    # Fix: added .fillna("") — consistent with the CLIP script's loader — so
    # empty cells become empty strings rather than NaN reaching the encoder.
    df_text = timeit(pd.read_csv)(data_path, sep="\t", index_col=0, header=None, names=["text"]).fillna("")
    print(f"df_text={df_text.shape}")
    # One worker process per visible GPU.
    pool = model.start_multi_process_pool(target_devices=[f"cuda:{i}" for i in range(torch.cuda.device_count())])
    print(f"pool={pool}")
    embeddings = model.encode_multi_process(df_text["text"].values, pool=pool, batch_size=args.batch_size)
    model.stop_multi_process_pool(pool)
    print(f"embeddings={embeddings.shape}")
    output_path = data_path.replace(".tsv", f".{model_type.name}.npy")
    print(f"Saving to: {output_path}")
    np.save(output_path, embeddings)
    # Fix: this message previously repeated "Saving to" after the save finished.
    print(f"Saved to: {output_path}")
def get_parser():
    """Return an argparse parser for the SentenceTransformer embedding CLI."""
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--model_type", help="Model type for SentenceTransformer")
    p.add_argument(
        "--data_path",
        help="data_path for data frame with first col as index and second col as text",
    )
    p.add_argument("--batch_size", type=int, default=512 * 8)
    return p
if __name__ == "__main__":
    # Script entry point: parse CLI args and run multi-GPU encoding.
    parser = get_parser()
    args = parser.parse_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment