Skip to content

Instantly share code, notes, and snippets.

@liangfu
Last active January 12, 2023 18:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save liangfu/25183b5871d74b8b97d266daaccad327 to your computer and use it in GitHub Desktop.
Save liangfu/25183b5871d74b8b97d266daaccad327 to your computer and use it in GitHub Desktop.
Benchmark inference time of pretrained AutoGluon Tabular models
# Benchmark inference time of pretrained AutoGluon Tabular models
"""
set -e
DEV_BRANCH=reset-thread-2
git checkout master
/opt/conda/envs/py38/bin/python3 example_dev_tabular.py --sample=1
/opt/conda/envs/py38/bin/python3 example_dev_tabular.py
git checkout $DEV_BRANCH
/opt/conda/envs/py38/bin/python3 example_dev_tabular.py --sample=1
/opt/conda/envs/py38/bin/python3 example_dev_tabular.py
"""
import time
from pygit2 import Repository
from autogluon.tabular import TabularPredictor, TabularDataset
def main(sample=None, *, n_repeats=100, warmup=0,
         model_path='trained_models/liangfu/'):
    """Benchmark per-model ``predict`` latency of a saved AutoGluon predictor.

    Loads the AdultIncome test split, loads the predictor stored at
    *model_path*, transforms the features once up front, then times
    ``predictor.predict`` for every model returned by
    ``predictor.get_model_names()``. For each model it prints the mean
    latency per call in milliseconds.

    Args:
        sample: If not ``None``, benchmark only the first *sample* rows of
            the test data (e.g. ``1`` measures single-row latency).
        n_repeats: Number of timed ``predict`` calls per model. The printed
            figure is their mean. Must be >= 1.
        warmup: Number of untimed ``predict`` calls issued per model before
            timing starts. ``0`` (the default) disables warm-up.
        model_path: Directory passed to ``TabularPredictor.load``.

    Raises:
        ValueError: If *n_repeats* is less than 1.
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")

    path_prefix = 'https://autogluon.s3.amazonaws.com/datasets/AdultIncomeBinaryClassification/'
    path_test = path_prefix + 'test_data.csv'
    label = 'class'
    test_data = TabularDataset(path_test)
    # NOTE(review): require_version_match=False presumably lets the same saved
    # predictor load on both git branches being compared - confirm.
    predictor = TabularPredictor.load(model_path, require_version_match=False)

    # Inference time:
    if sample is not None:
        test_data = test_data.head(sample)
    # Delete labels from test data since we wouldn't have them in practice.
    test_data = test_data.drop(labels=[label], axis=1)
    # Transform once here so the timed loop excludes feature preprocessing;
    # that is why predict() is called with transform_features=False below.
    test_data = predictor.transform_features(test_data)

    branch = Repository('.').head.shorthand
    print(f"config: len(test_data)={len(test_data)}, branch={branch}")

    # NOTE(review): intended to keep models resident in memory so timed calls
    # do not include loading them - confirm against AutoGluon docs.
    predictor.persist_models()
    for m in predictor.get_model_names():
        for _ in range(warmup):
            predictor.predict(test_data, model=m, transform_features=False)
        total_ms = 0.0
        for _ in range(n_repeats):
            # perf_counter is monotonic and high-resolution, unlike
            # time.time(), which can jump if the system clock is adjusted.
            tic = time.perf_counter()
            predictor.predict(test_data, model=m, transform_features=False)
            total_ms += (time.perf_counter() - tic) * 1000
        avg = total_ms / n_repeats
        print(f'Average: {avg:.1f} ms ({m})')
if __name__ == "__main__":
    # Command-line entry point: the optional --sample flag limits how many
    # test rows are benchmarked. Omitting it passes None to main().
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--sample",
        help="sample size for benchmarking",
        type=int,
    )
    parsed = cli.parse_args()
    main(parsed.sample)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment