Last active
March 22, 2022 08:58
-
-
Save philschmid/5f877444ce2f456b36907bbb785a21f1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import boto3 | |
import os | |
import json | |
from locust import User, task, between | |
# How to use
# 1. Install locust and boto3:
#    pip install locust boto3
# 2. Run the benchmark from the CLI.
#
# With the web UI:
# Requests are sent through a custom boto3 client rather than over HTTP, so enter "-" as the
# "Host" in the Locust web UI.
# ENDPOINT_NAME="distilbert-base-uncased-distilled-squad-6493832c-767d-4cdb-a9a" locust -f locust_benchmark_sm.py
#
# Headless:
# --users       Number of concurrent Locust users
# --spawn-rate  Number of users spawned per second until --users is reached
# --run-time    Duration of the test
# ENDPOINT_NAME="distilbert-base-uncased-distilled-squad-6493832c-767d-4cdb-a9a" locust -f locust_benchmark_sm.py \
#   --users 60 \
#   --spawn-rate 1 \
#   --run-time 360s \
#   --headless
# Optional local AWS profile and region. setdefault() only fills these in when they
# are not already exported, so a profile/region chosen in the shell is respected
# instead of being silently overwritten.
os.environ.setdefault("AWS_PROFILE", "hf-sm")
os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")

# Name of the SageMaker endpoint under test, taken from the environment.
ENDPOINT_NAME = os.environ.get("ENDPOINT_NAME")
# Reject both an unset and an empty variable, and do so before any AWS setup.
if not ENDPOINT_NAME:
    raise EnvironmentError("you have to define ENV ENDPOINT_NAME to run your benchmark")

# Module-level sagemaker-runtime client, created after the profile/region are set.
client = boto3.client("sagemaker-runtime")

# Alternative payload: a single plain-text input instead of a QA context/question pair.
# INPUT = {
#   "inputs": "Jeff: Can I train a Hugging Face Transformers model on Amazon SageMaker? Philipp: Sure you can use the new Hugging Face Deep Learning Container. Jeff: ok. Jeff: and how can I get started? Jeff: where can I find documentation? Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face. What a nice blog"
# }

# Request payload for a question-answering endpoint.
INPUT = {
    "inputs": {
        "context": "My Name is Philipp and I live in Nuremberg. This model is used with sagemaker for infernece.",
        "question": "Where lives Philipp?",
    }
}
class SageMakerClient:
    """Locust-aware wrapper around the boto3 ``sagemaker-runtime`` client.

    Each call to :meth:`send` invokes a SageMaker endpoint and reports the
    outcome (latency, response size, and any exception) to Locust via the
    ``request`` event of the attached environment, so the invocation appears
    in Locust's statistics like any other request.
    """

    # Set by the owning Locust user after construction; expected to be a Locust
    # Environment (anything exposing ``events.request.fire``).
    _locust_environment = None

    def __init__(self):
        super().__init__()
        self.client = boto3.client("sagemaker-runtime")

    @staticmethod
    def total_time(start_time) -> float:
        """Return the milliseconds elapsed since *start_time*.

        Args:
            start_time: A timestamp previously obtained from ``time.time()``.

        Returns:
            Elapsed time in milliseconds as a float; Locust accepts fractional
            response times, so no precision is thrown away.
        """
        return (time.time() - start_time) * 1000

    def send(self, name: str, payload: dict, endpoint_name: str):
        """Invoke *endpoint_name* with *payload* and report the result to Locust.

        Args:
            name: Label under which the request is aggregated in Locust stats.
            payload: JSON-serializable request body sent to the endpoint.
            endpoint_name: Name of the deployed SageMaker endpoint.

        Exactly one ``request`` event is fired per call. ``exception`` is set when
        invoking the endpoint or reading its response raised, and is ``None``
        otherwise. Errors are reported rather than re-raised, so a failing
        endpoint does not stop the simulated user.
        """
        start_time = time.time()
        response = None
        response_length = 0
        exception = None
        try:
            response = self.client.invoke_endpoint(
                EndpointName=endpoint_name,
                Body=json.dumps(payload),
                ContentType="application/json",
                Accept="application/json",
            )
            # Drain the streaming body so the measured latency covers the whole
            # response and the reported length reflects what the endpoint returned.
            response_length = len(response["Body"].read())
        except Exception as e:  # boundary: report any failure to Locust instead of crashing
            exception = e
        self._locust_environment.events.request.fire(
            request_type="execute",
            name=name,
            response_time=self.total_time(start_time),
            response_length=response_length,
            response=response,
            context={},
            exception=exception,
        )
class SageMakerUser(User):
    """Base Locust user whose ``client`` is a :class:`SageMakerClient`.

    Marked ``abstract`` so Locust does not run it directly; concrete subclasses
    provide ``wait_time`` and the ``@task`` methods.
    """

    abstract = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        sagemaker_client = SageMakerClient()
        # Hand the client this user's environment so it can fire Locust events.
        sagemaker_client._locust_environment = self.environment
        self.client = sagemaker_client
class SimpleSendRequest(SageMakerUser):
    """Locust user that repeatedly sends the module-level INPUT to ENDPOINT_NAME."""

    # Each user waits a random 0.1-1 s between consecutive tasks.
    wait_time = between(0.1, 1)

    @task
    def send_request(self):
        """Invoke the endpoint once, recorded in Locust stats as "send_endpoint"."""
        self.client.send("send_endpoint", INPUT, ENDPOINT_NAME)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment