# Source: GitHub Gist by @ngoodger, created January 12, 2020 11:48
# Gist id: 835efb5eb399f54af26749155da59d46
# (https://gist.github.com/ngoodger/835efb5eb399f54af26749155da59d46)
# File: test_request.py
import numpy as np
import requests
import time
import pyarrow
import torchvision
import torch
# Declared but not referenced anywhere in this script (the payload built below
# has a batch dimension of 4). NOTE(review): possibly mirrors a server-side
# setting — confirm before relying on it.
MAX_BATCH_SIZE: int = 1

# When True, each server response is compared against a local ResNet-18
# forward pass on the same input.
TEST_CORRECT_OUTPUT: bool = False

if TEST_CORRECT_OUTPUT:
    # Reference model is only constructed when verification is enabled.
    MODEL = torchvision.models.resnet18(pretrained=True)
    MODEL.eval()  # inference mode; eval() returns the same module in place
# A single fixed batch of random float32 data, reused as the body of every
# request. Shape (4, 3, 256, 256) is fed to the model as (batch, channels, H, W)
# in the verification path.
x = np.random.random(size=(4, 3, 256, 256)).astype(np.float32)

# Serialize once, before any timing starts, so encoding cost is not measured.
# NOTE(review): pyarrow.serialize/deserialize are deprecated (pyarrow >= 2.0)
# and removed in later releases; client and server must migrate together.
serialized_data = pyarrow.serialize(x).to_buffer()

# Round-trip time of each request, appended by the request loop below.
latencies = list()
# Endpoint under test and number of timed requests.
SERVER_URL = "http://localhost:8122/"
NUM_REQUESTS = 100

# The expected output depends only on the fixed input batch `x`, so compute it
# once up front instead of re-running the model on every iteration (which also
# inflated the reported total time).
if TEST_CORRECT_OUTPUT:
    with torch.no_grad():
        y = MODEL(torch.from_numpy(x))
    if x.shape[0] == 1:
        # NOTE(review): adds a leading dim for batch size 1, presumably to match
        # the server's response shape — confirm against the server.
        y = y.unsqueeze(0)
    y = y.numpy()

with requests.Session() as session:
    # perf_counter() is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump, corrupting interval measurements.
    start_time = time.perf_counter()
    for _ in range(NUM_REQUESTS):
        start_latency_time = time.perf_counter()
        result = session.post(SERVER_URL, data=serialized_data)
        end_latency_time = time.perf_counter()
        latencies.append(end_latency_time - start_latency_time)
        # Fail fast on HTTP errors rather than silently benchmarking error
        # responses (nothing else checks status when verification is off).
        result.raise_for_status()
        if TEST_CORRECT_OUTPUT:
            actual = pyarrow.deserialize(result.content)
            # Explicit raise instead of `assert`, which `python -O` strips.
            if not np.allclose(y, actual, atol=1e-4):
                raise AssertionError("Output is incorrect")
    end_time = time.perf_counter()
# Summary of the run: per-request latency extremes and mean, then the elapsed
# time of the whole request loop (same clock units as the loop's timers).
for label, value in (
    ("max latency", np.max(latencies)),
    ("min latency", np.min(latencies)),
    ("mean latency", np.mean(latencies)),
    ("total time", end_time - start_time),
):
    print(f"{label}: {value}")
# (Originally published as a GitHub Gist; discussion/comments live on the gist page.)