Last active: April 16, 2020 15:48
-
-
Save mavvverick/e22012a8e0b42432865f69bf776fce64 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*-coding:utf-8-*- | |
import glob
import json
import os
import sys

import numpy as np
import requests
from PIL import Image
# Width and height, in pixels, that every input image is resized to.
_IMAGE_SIZE = 299

# Prediction endpoint; the URL has the TensorFlow Serving REST form
# (/v1/models/<name>:predict) — presumably a local TF Serving instance.
SERVER_URL = 'http://localhost:8501/v1/models/nsfw:predict'

# Position of a score in each model output row -> human-readable class label.
_LABEL_MAP = dict(enumerate(('drawings', 'hentai', 'neutral', 'porn', 'sexy')))
def standardize(img):
    """Return *img* shifted to zero mean and scaled to unit standard deviation.

    The mean and standard deviation are taken over every element of the
    array at once (all pixels and channels together), not per channel.

    Args:
        img: NumPy array (or array-like) of numeric pixel values.

    Returns:
        A new array of the same shape. If *img* is constant (standard
        deviation of 0), the mean-centred array, which is all zeros, is
        returned rather than dividing by zero and producing NaNs.
    """
    mean = np.mean(img)
    std = np.std(img)
    centred = img - mean
    if std == 0:
        # A uniform image has no contrast to scale; 0/0 would yield NaN.
        return centred
    return centred / std
def load_image(folder_path):
    """Load every ``.jpg`` file under *folder_path* (recursively) as model input.

    Each image is converted to RGB, resized to ``_IMAGE_SIZE`` x
    ``_IMAGE_SIZE``, standardised with ``standardize``, cast to float16 and
    turned into nested Python lists so the whole batch can be JSON-encoded.

    Args:
        folder_path: directory to search; a trailing path separator is
            optional.

    Returns:
        A list with one nested-list image per matched file, ordered by sorted
        file path so the output order is deterministic.
    """
    # os.path.join keeps "**" a separate path component, which glob needs for
    # recursive matching regardless of whether folder_path ends in a slash.
    pattern = os.path.join(folder_path, "**", "*.jpg")
    input_list = []
    for image_path in sorted(glob.glob(pattern, recursive=True)):
        # Close the file handle as soon as the pixels have been decoded.
        with Image.open(image_path) as img:
            # Grayscale/CMYK JPEGs would otherwise produce arrays of a different
            # shape and break the batch. Assumes the served model takes
            # 3-channel RGB input — TODO confirm against the model signature.
            resized = img.convert("RGB").resize((_IMAGE_SIZE, _IMAGE_SIZE))
        data = np.asarray(resized, dtype="float32")
        data = standardize(data)
        # float16 shortens the JSON text sent to the server.
        data = data.astype(np.float16, copy=False)
        input_list.append(data.tolist())
    return input_list
def nsfw_predict(images_data_list, *, server_url=None, timeout=None):
    """POST a batch of preprocessed images and label the returned scores.

    Args:
        images_data_list: list of images as produced by ``load_image``
            (nested lists of numbers), sent as the ``"inputs"`` JSON field.
        server_url: endpoint to POST to; defaults to ``SERVER_URL``.
        timeout: seconds passed through to ``requests.post``; ``None`` (the
            default) waits indefinitely.

    Returns:
        If the decoded response has an ``"outputs"`` key: a list with one dict
        per output row, mapping each label in ``_LABEL_MAP`` to the score at
        that label's index. Otherwise the decoded response is returned
        unchanged (for example an error payload from the server). An empty
        input returns ``[]`` without contacting the server.
    """
    if not images_data_list:
        # Nothing to classify (e.g. no .jpg files were found); an empty batch
        # would presumably just be rejected by the server.
        return []
    pay_load = json.dumps({"inputs": images_data_list})
    response = requests.post(server_url or SERVER_URL, data=pay_load, timeout=timeout)
    data = response.json()
    if 'outputs' not in data:
        return data
    return [
        {label: output[index] for index, label in _LABEL_MAP.items()}
        for output in data['outputs']
    ]
if __name__ == '__main__':
    # Command-line entry: classify every .jpg under the folder given as the
    # first argument and print the labelled scores.
    args = sys.argv
    if len(args) < 2:
        # Stop here; continuing would crash with IndexError on args[1].
        print("usage: python serving_client.py <image_folder>", file=sys.stderr)
        sys.exit(1)
    image_path = args[1]
    images_data_list = load_image(image_path)
    predict = nsfw_predict(images_data_list)
    print(predict)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.