"""This module is used for running zero shot classification with multiple GPUs"""
import pandas as pd
import numpy as np
from transformers import pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import set_seed
from tqdm.notebook import tqdm
import time
import math
from multiprocessing import Pool, current_process, Queue
class multi_gpu_inference:
def __init__(
self,
df,
candidate_labels,
text_column="text_content_cleaned",
multi_class=True,
transformer_model_name="facebook/bart-large-mnli",
batch_size=4,
num_gpus=8,
proc_per_gpu=1,
seed_value=42,
):
self.df = df
self.candidate_labels = candidate_labels
self.text_column = text_column
self.multi_class = multi_class
self.batch_size = batch_size
self.num_gpus = num_gpus
self.proc_per_gpu = proc_per_gpu
self.seed_value = seed_value
self.transformer_model_name = transformer_model_name
self.tokenizer = AutoTokenizer.from_pretrained(self.transformer_model_name)
def _load_model(self, gpu_id):
"""Load model into memory"""
set_seed(seed_value)
classifier = pipeline(
"zero-shot-classification",
model=self.transformer_model_name,
tokenizer=self.tokenizer,
framework="pt",
device=gpu_id, # enable GPU
)
return classifier
def _split_data_batches(self, text_docs: list) -> list:
"""Split data into batches of a specified size"""
data_batches = np.array_split(
text_docs,
math.ceil(len(text_docs) / self.batch_size),
)
return data_batches
def _predict_data_batches(self, data_chunks, gpu_id, classifier) -> list:
"""Make predictions in batches"""
results = []
text_desc = (
"Classifying with CPU" if gpu_id == -1 else f"Classifying with GPU {gpu_id}"
)
for data in tqdm(
data_chunks,
total=len(data_chunks),
desc=text_desc,
):
chunk_size = len(data)
result = classifier(list(data), self.candidate_labels, multi_class=True)
results.extend([result]) if chunk_size == 1 else results.extend(result)
return results
def _convert_model_results_df(self, results) -> pd.DataFrame:
"""Convert model results into a pandas dataframe"""
df = pd.DataFrame()
for result in results:
df_result = pd.DataFrame(
data=[result["scores"] + [result["sequence"]]],
columns=result["labels"] + ["sequence"],
)
df = df.append(df_result)
df.reset_index(drop=True, inplace=True)
df.columns = df.columns.str.replace(" ", "_")
return df
def _model_results(self, text_docs: list) -> pd.DataFrame:
"""Determine model results"""
gpu_id = queue.get() # get GPU not already in use
try:
# load model on GPU
classifier = self._load_model(gpu_id)
# split data into batches
data_chunks = self._split_data_batches(text_docs)
# label data in batches
results = self._predict_data_batches(data_chunks, gpu_id, classifier)
# convert results to dataframe
df = self._convert_model_results_df(results)
finally:
queue.put(gpu_id)
return df
def _split_data_chunks_GPUs(self):
"""Split data into equal chunks for each GPU"""
data_chunks = np.array_split(
self.df[self.text_column].values, min(len(self.df), self.num_gpus)
)
return data_chunks
def run(self):
global queue # set queue as a global variable (it cannot be set as a self param)
queue = Queue()
# initialize the queue with the GPU ids
for gpu_id in range(self.num_gpus):
for _ in range(self.proc_per_gpu):
queue.put(gpu_id)
pool = Pool(processes=self.proc_per_gpu * self.num_gpus)
# split data into chunks for each GPU.
# A different function could be used so that each GPU has the same byte size of data rather than the same number of data points.
data_chunks = self._split_data_chunks_GPUs()
# obtain results for each GPU
results = []
for _ in tqdm(
pool.imap(self._model_results, data_chunks),
total=len(data_chunks),
desc="GPU process",
):
results.append(_)
# close pool and join results
pool.close()
pool.join()
df_results = pd.concat(results)
return df_results
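
As noted in run, the chunks could be balanced by size rather than by number of data points. A minimal sketch of such a helper, as a drop-in replacement for _split_data_chunks_GPUs; the greedy balancing below uses character count as a proxy for byte size and is an illustration, not part of the original gist:

    def _split_data_chunks_by_size(self):
        """Greedily assign each document to the chunk with the fewest characters so far"""
        n_chunks = min(len(self.df), self.num_gpus)
        chunks = [[] for _ in range(n_chunks)]
        sizes = [0] * n_chunks
        # longest documents first keeps the greedy assignment well balanced
        for doc in sorted(self.df[self.text_column].values, key=len, reverse=True):
            smallest = sizes.index(min(sizes))
            chunks[smallest].append(doc)
            sizes[smallest] += len(doc)
        return [np.array(chunk) for chunk in chunks]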
@nialloriordan (Author)

Run zero shot classification with multiple GPUs as follows:

gpu_inference = multi_gpu_inference(
    df, # pandas dataframe
    labels, # list of labels
    num_gpus=8, # number of GPUs
    batch_size=4, # batch size
    text_column=text_column, # text column to use from `df`
)
df_results = gpu_inference.run()
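
In df_results, each row corresponds to one input document, with one score column per label (spaces in label names replaced by underscores) and a sequence column holding the original text.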

@nialloriordan (Author)

To speed up initialisation, the model could first be saved to disk via classifier.save_pretrained so that each GPU then loads the model from disk instead of downloading it.
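
A minimal sketch of that approach (the models/bart-large-mnli path is illustrative):

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
# one-off: save the pipeline's model and tokenizer to a local directory
classifier.save_pretrained("models/bart-large-mnli")  # illustrative path

# then point multi_gpu_inference at the local copy so each GPU loads from disk
gpu_inference = multi_gpu_inference(
    df, labels, transformer_model_name="models/bart-large-mnli"
)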

@nialloriordan (Author)

The run function in multi_gpu_inference can be extended to tasks other than zero shot classification that are not otherwise easy to run on multiple GPUs.
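
For example, a hedged sketch of a subclass for sentiment analysis (sentiment_gpu_inference and its overrides are illustrative, not part of the gist); the prediction and dataframe-conversion steps need matching overrides because the originals are specific to zero shot output:

class sentiment_gpu_inference(multi_gpu_inference):
    """Illustrative subclass: same multi-GPU machinery, different pipeline task"""

    def _load_model(self, gpu_id):
        set_seed(self.seed_value)
        return pipeline(
            "sentiment-analysis",
            model=self.transformer_model_name,
            tokenizer=self.tokenizer,
            framework="pt",
            device=gpu_id,
        )

    def _predict_data_batches(self, data_chunks, gpu_id, classifier):
        results = []
        for data in data_chunks:
            # sentiment pipelines take only the texts, no candidate labels
            results.extend(classifier(list(data)))
        return results

    def _convert_model_results_df(self, results):
        # sentiment results are dicts of the form {"label": ..., "score": ...}
        return pd.DataFrame(results)

Usage would mirror the zero shot case, e.g. sentiment_gpu_inference(df, candidate_labels=[], transformer_model_name="distilbert-base-uncased-finetuned-sst-2-english").run().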
