Elastic - NLP - Load HuggingFace Model with Zero Shot Example
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": " Elastic - NLP - Load HuggingFace Model with Zero Shot Example",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOroRibhoorNRL0/+AfddoE",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jeffvestal/551cb07928dd01adec060521e21c3612/-elastic-nlp-load-huggingface-model-with-zero-shot-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"Use this code to load a NLP model from Hugging Face for use inside Elastic's elasticsearch. \n",
"\n",
"You can set up a [free trial elasticsearch Deployment in Elastic Cloud](https://cloud.elastic.co/registration) and run the below code in [Google's Colab](https://colab.research.google.com) for free.\n",
"\n",
"Requires Elastic version 8.0+ with a platinum or enterprise license (or trial license)\n",
"\n",
"Example here is loading a [Zero Shot model](https://huggingface.co/typeform/distilbert-base-uncased-mnli)\n",
"\n",
"[Elastic NLP Model Support Docs](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-model-ref.html) \n",
"\n",
"[Code summarized from the eland docs](https://github.com/elastic/eland)\n",
"\n",
"Disclaimer: presented as is with no guarantee."
],
"metadata": {
"id": "6xoLDtS_6Df1"
}
},
{
"cell_type": "markdown",
"source": [
"# Install eland and *elasticsearch*"
],
"metadata": {
"id": "Ly1f1P-l9ri8"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rUedSzQW9FIF"
},
"outputs": [],
"source": [
"pip install eland"
]
},
{
"cell_type": "code",
"source": [
"pip install elasticsearch"
],
"metadata": {
"id": "NK3Wx1I199yB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pip install transformers"
],
"metadata": {
"id": "cEfiiFXakzdP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pip install sentence_transformers"
],
"metadata": {
"id": "I20mDmJboKZw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pip install torch==1.11"
],
"metadata": {
"id": "uqcpWrbkBEB9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"from pathlib import Path\n",
"from eland.ml.pytorch import PyTorchModel\n",
"from eland.ml.pytorch.transformers import TransformerModel\n",
"from elasticsearch import Elasticsearch\n",
"from elasticsearch.client import MlClient\n"
],
"metadata": {
"id": "-dqhRCBUe1U-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Set Connection and Auth"
],
"metadata": {
"id": "r7nMIbHke37Q"
}
},
{
"cell_type": "code",
"source": [
"es_url = getpass.getpass('Enter elasticsearch endpoint: ') # endpoint https://<esurl>:<port>\n",
"es_user = getpass.getpass('Enter username: ') # username\n",
"es_pass = getpass.getpass('Enter password: ') # password"
],
"metadata": {
"id": "SSGgYHome69o"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Connect to Elastic and Load a Hugging Face Model"
],
"metadata": {
"id": "jL4VDnVp96lf"
}
},
{
"cell_type": "code",
"source": [
"es = Elasticsearch(es_url, basic_auth=(es_user,es_pass))"
],
"metadata": {
"id": "I8mVJkKmetXo"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"[Supported `task_type` values](https://github.com/elastic/eland/blob/15a300728876022b206161d71055c67b500a0192/eland/ml/pytorch/transformers.py#*L41*)"
],
"metadata": {
"id": "QmZ1fkwYM5er"
}
},
{
"cell_type": "code",
"source": [
"# Download a Hugging Face Zero Shot model directly from the model hub\n",
"\n",
"# https://huggingface.co/typeform/distilbert-base-uncased-mnli\n",
"#tm = TransformerModel(\"sentence-transformers/all-MiniLM-L12-v2\", \"text_embedding\")\n",
"tm = TransformerModel(\"distilbert-base-cased-distilled-squad\", \"question_answering\")"
],
"metadata": {
"id": "zPV3oFsKiYFL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Export the model in a TorchScrpt representation which Elasticsearch uses\n",
"tmp_path = \"models\"\n",
"Path(tmp_path).mkdir(parents=True, exist_ok=True)\n",
"# model_path, config_path, vocab_path = tm.save(tmp_path) #pre 8.2.0\n",
"model_path, config, vocab_path = tm.save(tmp_path)"
],
"metadata": {
"id": "GsSpvvP-nbCK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Import model into Elasticsearch\n",
"ptm = PyTorchModel(es, tm.elasticsearch_model_id())\n",
"# ptm.import_model(model_path, config_path, vocab_path) # pre 8.2.0\n",
"ptm.import_model(model_path=model_path, config_path=None, vocab_path=vocab_path, config=config) "
],
"metadata": {
"id": "Z4QD71Apnj4j"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Deploy the model"
],
"metadata": {
"id": "oMGw3sk-pbaN"
}
},
{
"cell_type": "code",
"source": [
"# List models in elasticsearch\n",
"m = MlClient.get_trained_models(es, )\n",
"m.body"
],
"metadata": {
"id": "b4Wv8EJvpfZI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Deploy the model\n",
"\n",
"# Model is listed -> 'model_id': 'typeform__distilbert-base-uncased-mnli'\n",
"model_id='distilbert-base-cased-distilled-squad'\n",
"\n",
"# start trained model deployment\n",
"s = MlClient.start_trained_model_deployment(es, model_id=model_id)\n",
"s.body\n",
"\n",
"# You can see model state in Kibana -> Machine Learning -> Model Management -> Trained Models"
],
"metadata": {
"id": "w5muJ1rLqvUW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Zero Shot Time!"
],
"metadata": {
"id": "6Hu2n4bmGYkG"
}
},
{
"cell_type": "code",
"source": [
"# future reference do not use yet\n",
"#z = MlClient.infer_trained_model_deployment(es, model_id =model_id, docs=docs, )"
],
"metadata": {
"id": "ZsWg7XPSGbiu"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Using requests until MlClient.infer_trained_model_deployment is updated to accept inference extra configs\n",
"import requests\n",
"from requests.auth import HTTPBasicAuth\n",
"import urllib.parse\n",
"\n",
"endpoint = '_ml/trained_models/%s/deployment/_infer' % model_id\n",
"url = urllib.parse.urljoin(es_url, endpoint)\n",
"\n",
"body = {\n",
" \"docs\": [\n",
" {\n",
" \"text_field\": \"Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.\"\n",
" }\n",
" ],\n",
" \"inference_config\": {\n",
" \"zero_shot_classification\": {\n",
" \"labels\": [\n",
" \"mobile\",\n",
" \"website\",\n",
" \"billing\",\n",
" \"account access\"\n",
" ],\n",
" \"multi_label\": True\n",
" }\n",
" }\n",
"}\n",
"\n",
"resp = requests.post(url, auth=HTTPBasicAuth(es_user, es_pass), json=body)\n",
"r = resp.json()\n",
"print('Predicted value is: %s with a probability of %0.2f%%' % (r['predicted_value'], r['prediction_probability'] * 100))\n",
"print('=-=-=-=')\n",
"print('Full Probability output:')\n",
"for c in r['top_classes']:\n",
" print ('%s probability of %0.5f%%' % (c['class_name'], c['class_probability'] * 100))"
],
"metadata": {
"id": "tf9c-XkrQTM3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Just to see the full doc\n",
"resp.json()"
],
"metadata": {
"id": "f3JRG4SeaESo"
},
"execution_count": null,
"outputs": []
}
]
}
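
The last code cell of the notebook builds the zero-shot inference request and prints the response inline. As a rough sketch, the same steps can be split into two small helpers that run without a cluster. The function names `build_infer_request` and `format_prediction` are hypothetical (they are not part of eland or the elasticsearch client), and the response keys (`predicted_value`, `prediction_probability`, `top_classes`) are assumed from what the notebook prints; the actual response shape may differ between Elasticsearch versions.

```python
from urllib.parse import urljoin


def build_infer_request(es_url, model_id, text, labels, multi_label=True):
    """Return (url, body) for the deployment _infer call used in the notebook.

    Hypothetical helper: it only assembles data and sends nothing.
    """
    # Add a trailing slash so urljoin appends the endpoint to the base URL
    # instead of replacing the last path segment.
    base = es_url if es_url.endswith("/") else es_url + "/"
    url = urljoin(base, "_ml/trained_models/%s/deployment/_infer" % model_id)
    body = {
        "docs": [{"text_field": text}],
        "inference_config": {
            "zero_shot_classification": {
                "labels": list(labels),
                "multi_label": multi_label,
            }
        },
    }
    return url, body


def format_prediction(result):
    """Format a response dict shaped like the one the notebook prints.

    Assumes the flat keys used in the notebook; raises KeyError otherwise.
    """
    lines = [
        "Predicted value is: %s with a probability of %0.2f%%"
        % (result["predicted_value"], result["prediction_probability"] * 100)
    ]
    for c in result.get("top_classes", []):
        lines.append(
            "%s probability of %0.5f%%"
            % (c["class_name"], c["class_probability"] * 100)
        )
    return lines
```

In the notebook, the pair returned by `build_infer_request(es_url, model_id, text, labels)` could be passed to `requests.post(url, auth=HTTPBasicAuth(es_user, es_pass), json=body)`, and `format_prediction(resp.json())` would give the lines to print.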