@geoffreyangus
Created July 19, 2022 23:42
This Gist demonstrates how to perform inference using a Torchscripted LudwigModel.
{
"cells": [
{
"cell_type": "markdown",
"id": "197d0bdc-2481-419a-ada0-101f1ee54582",
"metadata": {},
"source": [
"# Torchscripted LudwigModel Example\n",
"\n",
"This notebook demonstrates how one might use a torchscripted LudwigModel for inference."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "aa6e0c56-cafd-43c0-bd61-4abafc58e761",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/geoffreyangus/repositories/predibase/ludwig/venv38/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from sklearn.metrics import roc_auc_score, accuracy_score\n",
"from tqdm import tqdm\n",
"\n",
"from ludwig.utils.inference_utils import to_inference_module_input_from_dataframe\n",
"from ludwig.utils.data_utils import load_json\n",
"\n",
"SRC_DIR = '/Users/geoffreyangus/Downloads/'"
]
},
{
"cell_type": "markdown",
"id": "2d3e16de-68cd-4811-8ebc-e5575aa9cf6b",
"metadata": {},
"source": [
"**Step 1: Loading the dataset and the inference module**\n",
"\n",
"The TorchScript model can be loaded using `torch.jit.load`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8aa45deb-e31e-4704-9cb0-054dcb766670",
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv(os.path.join(SRC_DIR, 'data/test.tsv'), sep='\\t')\n",
"inference_module = torch.jit.load(os.path.join(SRC_DIR, 'export_torchscript/inference_module'))"
]
},
{
"cell_type": "markdown",
"id": "cf157e45-89d4-4409-85cd-2ac969f51238",
"metadata": {},
"source": [
"**Step 2: Performing inference on a sample input**\n",
"\n",
"For Python backends, Ludwig provides a helper function, `to_inference_module_input_from_dataframe`, which takes a raw DataFrame, formatted exactly as it was during training, together with the model config.\n",
"\n",
"We can use this helper function to create the `Dict[List]` of inputs expected by the inference module. Then, we can pass the resulting object into the inference module and get our first set of predictions."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2369ec9a-4c34-4778-a151-134cef66aa44",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/geoffreyangus/repositories/predibase/ludwig/venv38/lib/python3.8/site-packages/torch/nn/modules/module.py:1110: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at /Users/distiller/project/pytorch/aten/src/ATen/native/Convolution.cpp:744.)\n",
" return forward_call(*input, **kwargs)\n"
]
},
{
"data": {
"text/plain": [
"{'label': {'predictions': tensor([False, False, False, False, False]),\n",
" 'probabilities': tensor([[1.0000e+00, 1.9320e-17],\n",
" [1.0000e+00, 0.0000e+00],\n",
" [1.0000e+00, 2.5421e-14],\n",
" [1.0000e+00, 3.4744e-20],\n",
" [1.0000e+00, 6.5306e-32]])}}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"config = load_json(os.path.join(SRC_DIR, 'src/experiment_run_96_alpha/model/model_hyperparameters.json'))\n",
"\n",
"# Preprocess the sample rows using the helper function `to_inference_module_input_from_dataframe`\n",
"sample_df = df.sample(5)\n",
"sample_input = to_inference_module_input_from_dataframe(sample_df, config)\n",
"\n",
"# Call the inference module\n",
"inference_module(sample_input)"
]
},
{
"cell_type": "markdown",
"id": "cfee12e5-cd06-4c52-808f-c6c971491718",
"metadata": {},
"source": [
"**Step 3 (Optional): Performing batch inference on the full dataset**\n",
"\n",
"The TorchScript module can be deployed in any environment with minimal dependencies. Below is a quick example of how one might run inference across the entire dataset in batches and compute metrics from the resulting output."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a75163c3-16f0-4536-bfb6-748f197cccdc",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 371/371 [00:04<00:00, 76.14it/s]\n"
]
}
],
"source": [
"# Sample batch processing\n",
"num_samples = len(df)\n",
"batch_size = 32\n",
"\n",
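"# Use at least one chunk so that datasets smaller than `batch_size` do not request zero splits\n",
"all_batch_idxs = np.array_split(np.arange(num_samples), max(1, num_samples // batch_size))\n",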
"batch_labels = []\n",
"batch_predictions = []\n",
"batch_probabilities = []\n",
"for batch_idxs in tqdm(all_batch_idxs):\n",
" batch_df = df.iloc[batch_idxs]\n",
" batch_labels.append(batch_df['label'].values.tolist())\n",
" \n",
" batch_input = to_inference_module_input_from_dataframe(batch_df, config)\n",
" outputs = inference_module(batch_input)\n",
" batch_predictions.append(outputs['label']['predictions'])\n",
" batch_probabilities.append(outputs['label']['probabilities'][:, 1])\n",
"\n",
"batch_labels = np.concatenate(batch_labels)\n",
"batch_predictions = np.concatenate(batch_predictions)\n",
"batch_probabilities = np.concatenate(batch_probabilities)\n",
"\n",
"# Metrics\n",
"# print(accuracy_score(batch_labels, batch_predictions))\n",
"# print(roc_auc_score(batch_labels, batch_probabilities))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
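
The batch loop in Step 3 splits the row indices into `num_samples // batch_size` chunks, so individual batches can be slightly larger than `batch_size`. As a minimal, dependency-free sketch of an alternative, the snippet below yields fixed-size `(start, stop)` bounds and computes a plain accuracy score. The names `iter_batch_bounds` and `accuracy`, and the toy labels, are hypothetical and not part of Ludwig or the notebook; the always-`False` prediction is only a stand-in for the call to `inference_module`.

```python
from typing import Iterator, List, Sequence, Tuple


def iter_batch_bounds(num_samples: int, batch_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, stop) row bounds so that no batch exceeds batch_size."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, num_samples, batch_size):
        yield start, min(start + batch_size, num_samples)


def accuracy(labels: Sequence[bool], predictions: Sequence[bool]) -> float:
    """Return the fraction of predictions that match the labels."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    if len(labels) == 0:
        raise ValueError("cannot compute accuracy of an empty sequence")
    matches = sum(1 for y, p in zip(labels, predictions) if y == p)
    return matches / len(labels)


# Toy stand-ins for the notebook's `df['label']` column (made-up values).
toy_labels = [False, True, False, False, True, False, True]

bounds = list(iter_batch_bounds(len(toy_labels), batch_size=3))

toy_predictions: List[bool] = []
for start, stop in bounds:
    # In the notebook, this is where `df.iloc[start:stop]` would be converted
    # and passed to `inference_module`; here every row is predicted False.
    toy_predictions.extend([False] * (stop - start))

toy_accuracy = accuracy(toy_labels, toy_predictions)
```

With this scheme only the final batch can be shorter than `batch_size`, and an empty dataset simply produces no batches.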