-
-
Save geoffreyangus/638bb5a200375568676ad5e610348085 to your computer and use it in GitHub Desktop.
TorchScript Benchmarking Script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib | |
import logging | |
import os | |
import random | |
import time | |
import yaml | |
from collections import defaultdict | |
import pandas as pd | |
import torch | |
from tqdm import tqdm | |
from ludwig.api import LudwigModel | |
from ludwig.models.inference import InferenceModule, to_inference_module_input_from_dataframe | |
@contextlib.contextmanager
def timeit(duration=None, precision=5):
    """Measure how long the body of a ``with`` block takes.

    Args:
        duration: Optional list. If given, the elapsed time in seconds is
            appended to it when the block exits, including when the block
            raises.
        precision: Number of decimal places the elapsed time is rounded to
            (default 5, the original behaviour). Pass ``None`` to record the
            unrounded value.

    Yields:
        None.
    """
    # perf_counter is the clock intended for measuring short durations; unlike
    # time.time() it is not affected by system clock adjustments.
    start_t = time.perf_counter()
    try:
        yield
    finally:
        if duration is not None:
            elapsed = time.perf_counter() - start_t
            # round(x, None) would return an int, so skip rounding explicitly.
            duration.append(elapsed if precision is None else round(elapsed, precision))
# Which benchmark to run. Defaults to 'AGNEWS'; set the BENCHMARK_MODE
# environment variable (e.g. BENCHMARK_MODE=TITANIC) to switch datasets
# without editing the script. The value is upper-cased before comparison.
MODE = os.environ.get('BENCHMARK_MODE', 'AGNEWS').upper()

# Per-mode settings consumed by app():
#   experiment_name_prefix -- short lowercase name for the experiment
#   dataset_path           -- CSV used both for training and for sampling test batches
#   num_test_batches       -- timed batches per (model type, batch size) pair
#   num_warmup_batches     -- untimed batches run through TorchScript models first
#   config_str             -- Ludwig config as YAML text
if MODE == 'TITANIC':
    experiment_name_prefix = 'titanic'
    dataset_path = '/home/ray/titanic_train.csv'
    num_test_batches = 100
    num_warmup_batches = 100
    config_str = """
input_features:
  - name: Pclass
    type: category
  - name: Sex
    type: category
  - name: Age
    type: number
    preprocessing:
      missing_value_strategy: fill_with_mean
  - name: SibSp
    type: number
  - name: Parch
    type: number
    preprocessing:
      normalization: zscore
  - name: Fare
    type: number
    preprocessing:
      missing_value_strategy: fill_with_mean
  - name: Embarked
    type: category
output_features:
  - name: Survived
    type: binary
trainer:
  batch_size: 128
  train_steps: 4
  steps_per_checkpoint: 2
  early_stop: 1
backend:
  type: ray
  processor:
    type: dask
  trainer:
    num_workers: 1
"""
elif MODE == 'AGNEWS':
    experiment_name_prefix = 'agnews'
    dataset_path = '/home/ray/agnews_tiny.csv'
    num_test_batches = 10
    num_warmup_batches = 100
    config_str = """
input_features:
  - name: description
    type: text
output_features:
  - name: class_index
    type: category
trainer:
  batch_size: 1024
  train_steps: 1
  steps_per_checkpoint: 1
  early_stop: 0
backend:
  type: ray
  processor:
    type: dask
  trainer:
    num_workers: 1
"""
else:
    raise ValueError(f'Invalid MODE: {MODE}')
def _synchronize(device):
    """Wait for queued CUDA work to finish when *device* is ``'cuda'``.

    CUDA kernel launches are asynchronous, so a timer stopped without this
    can miss GPU work that has been queued but not yet executed.
    """
    if device == 'cuda':
        torch.cuda.synchronize()


def _load_model(model_path, model_type, device):
    """Load the model saved under *model_path* as the given flavour.

    Reads local files instead of MLflow artifacts.

    Args:
        model_path: Directory the Ludwig model and its TorchScript export
            were saved to.
        model_type: One of ``'ludwig_model'``, ``'single_module'`` or
            ``'pipeline'``.
        device: Passed to ``InferenceModule.from_directory`` for the
            ``'pipeline'`` flavour; not used by the other flavours.

    Raises:
        ValueError: If *model_type* is not one of the values above.
    """
    if model_type == 'ludwig_model':
        return LudwigModel.load(model_path, backend='local')
    if model_type == 'single_module':
        return torch.jit.load(os.path.join(model_path, 'single_module.pt'))
    if model_type == 'pipeline':
        # To compare against a single scripted module, wrap the result in
        # torch.jit.script(...).
        return InferenceModule.from_directory(model_path, device=device)
    raise ValueError(f'Invalid model_type: {model_type}')


def _warmup(loaded_model, test_df, config, device, batch_sizes, num_batches):
    """Run *num_batches* untimed TorchScript predictions with random batch sizes."""
    print(f'Warming up model with type {type(loaded_model)} with {num_batches} batches of size random.choice({batch_sizes})')
    for _ in tqdm(range(num_batches)):
        warmup_batch = create_batch(test_df, config, device, random.choice(batch_sizes), True)
        model_predict(loaded_model, warmup_batch, True)
    # Make sure no warmup work is still queued when timing starts.
    _synchronize(device)
    print('Warmup done')


def _time_batches(loaded_model, test_batches, is_torchscript, device):
    """Return the per-batch prediction latency in seconds for *test_batches*."""
    duration = []
    for test_batch in tqdm(test_batches, total=len(test_batches)):
        with timeit(duration):
            model_predict(loaded_model, test_batch, is_torchscript)
            # Count the GPU execution itself, not just the kernel launches.
            _synchronize(device)
    return duration


def app() -> None:
    """Train a small Ludwig model, export it to TorchScript and benchmark inference.

    For each available device ('cpu', plus 'cuda' when torch reports it as
    available), three flavours of the model ('ludwig_model', 'single_module',
    'pipeline') are timed at several batch sizes. Per-batch latencies are
    written to ``/home/ray/{MODE}_results.csv``.
    """
    config = yaml.safe_load(config_str)
    model = LudwigModel(
        config=config,
        logging_level=logging.INFO,
    )
    # NOTE(review): experiment_name_prefix (defined alongside MODE) is not used
    # here; confirm whether it was meant to be the experiment name.
    _, _, output_dir = model.train(
        dataset=dataset_path,
        experiment_name=f'{MODE}',
        model_name='local_test',
        output_directory='/home/ray/results/',
    )

    # These do not depend on the device, so compute them once.
    model_path = os.path.join(output_dir, 'model')
    test_df = pd.read_csv(dataset_path)
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
    model_types = ['ludwig_model', 'single_module', 'pipeline']
    torchscript_types = {'single_module', 'pipeline'}

    devices = ['cpu', 'cuda']
    if not torch.cuda.is_available():
        # Exporting or running on 'cuda' would fail on a CPU-only machine.
        print('CUDA is not available; skipping cuda benchmarks')
        devices = ['cpu']

    results = defaultdict(list)
    for device in devices:
        # Same model_path for every device; presumably this replaces the
        # previous device's export.
        model.save_torchscript(model_path, device=device)
        print('model_path:', model_path)

        loaded_models = {}
        for model_type in model_types:
            loaded_model = _load_model(model_path, model_type, device)
            # Only the TorchScript flavours are warmed up.
            if model_type in torchscript_types:
                _warmup(loaded_model, test_df, model.config, device, batch_sizes, num_warmup_batches)
            loaded_models[model_type] = loaded_model

        for batch_size in batch_sizes:
            for model_type, loaded_model in loaded_models.items():
                is_torchscript = model_type in torchscript_types
                # Build all batches up front so batch construction is not timed.
                test_batches = [
                    create_batch(test_df, model.config, device, batch_size, is_torchscript)
                    for _ in range(num_test_batches)
                ]
                print(f'benchmarking model with type "{model_type}" with batch size {batch_size}')
                duration = _time_batches(loaded_model, test_batches, is_torchscript, device)
                print(f'results: {duration}')
                # NOTE(review): for 'ludwig_model' the recorded device is the loop
                # device, but LudwigModel.load is not given a device; confirm
                # where it actually runs.
                for d in duration:
                    results['name'].append(model_type)
                    results['device'].append(device)
                    results['num_batches'].append(num_test_batches)
                    results['batch_size'].append(batch_size)
                    results['duration'].append(d)

    output_df_path = os.path.join('/home/ray/', f'{MODE}_results.csv')
    pd.DataFrame(dict(results)).to_csv(output_df_path, index=False)
    print(f'Complete results saved to {output_df_path}')
    print(f'Finished experiment for {dataset_path}')
    print('model_path: ', model_path)
    print('Done')
def create_batch(df, config, device, batch_size, is_torchscript):
    """Draw a random batch of ``batch_size`` rows from ``df`` (with replacement).

    Non-TorchScript models get the sampled DataFrame unchanged. TorchScript
    models get it converted with ``to_inference_module_input_from_dataframe``
    (with ``load_paths=True``) on ``device``.
    """
    sampled_rows = df.sample(n=batch_size, replace=True)
    if not is_torchscript:
        return sampled_rows
    return to_inference_module_input_from_dataframe(
        sampled_rows, config, load_paths=True, device=device
    )
def model_predict(model, batch, is_torchscript):
    """Run one prediction on ``batch`` and return the model's output.

    TorchScript modules are called directly; other models (LudwigModel)
    go through their ``predict`` method.
    """
    predict_fn = model if is_torchscript else model.predict
    return predict_fn(batch)
if __name__ == '__main__':
    # Only run the benchmark when executed as a script, not on import.
    app()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment