Created
February 17, 2020 18:24
-
-
Save bgweber/cef19bf1931139013af24bb32d67ae26 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from google.cloud import monitoring_v3 | |
from google.oauth2 import service_account | |
from google.cloud import logging | |
import socket | |
import random | |
import time | |
import pandas as pd | |
from sklearn.linear_model import LogisticRegression | |
import flask | |
from multiprocessing import Value | |
import threading | |
# Build an identifier for this pod: the resolved IP address of the local
# host name followed by a random integer suffix.
host = "{} - {}".format(
    socket.gethostbyname(socket.gethostname()),
    random.randint(0, 1000000),
)

# Load the GCP service-account key and open the Stackdriver monitoring client.
credentials = service_account.Credentials.from_service_account_file('serving.json')
client = monitoring_v3.MetricServiceClient(credentials=credentials)
project_name = client.project_path('serving-268422')

# Open the Stackdriver logging client and record that the service is starting.
logging_client = logging.Client(project='serving-268422', credentials=credentials)
logger = logging_client.logger('model_service')
logger.log_text("({}) Launching model service".format(host))

# Train a scikit-learn logistic regression at import time; every column
# except 'label' is used as a feature and 'label' is the target.
df = pd.read_csv("https://github.com/bgweber/Twitch/raw/master/Recommendations/games-expand.csv")
model = LogisticRegression()
model.fit(df.drop(columns=['label']), df['label'])

# Flask application and a lock-protected integer that tracks how many
# requests have been received since the last metric write.
app = flask.Flask(__name__)
counter = Value('i', 0)
# Publish a single data point of the custom requests metric to Stackdriver.
def write_metric_value(value):
    """Write ``value`` as one point of the custom requests time series.

    The series is labelled with this pod's ``host`` string under the 'ip'
    label, and the point's end time is the current time truncated to whole
    seconds.

    Args:
        value: Number stored in the point's ``double_value`` field.
    """
    time_series = monitoring_v3.types.TimeSeries()
    time_series.metric.labels['ip'] = host
    time_series.metric.type = 'custom.googleapis.com/serving/requests'

    sample = time_series.points.add()
    sample.interval.end_time.seconds = int(time.time())
    sample.value.double_value = value

    client.create_time_series(project_name, [time_series])
# Callback that records requests per minute to Stackdriver.
def log_requests():
    """Report the request count since the previous run and reset it to zero.

    Each call first schedules the next call 60 seconds later, then reads and
    clears ``counter`` under its lock and passes the value to
    ``write_metric_value``.
    """
    # Re-arm before doing any work so that a failure while writing the metric
    # does not stop future reports. The timer thread is marked as a daemon so
    # a pending timer never keeps the interpreter alive at shutdown.
    timer = threading.Timer(60.0, log_requests)
    timer.daemon = True
    timer.start()

    # Read and reset under the lock so increments made by request handlers
    # between the two statements are not lost.
    with counter.get_lock():
        requests = counter.value
        counter.value = 0

    print("writing value: " + str(requests))
    write_metric_value(requests)


# Start the per-minute request tracking as soon as the module is loaded.
log_requests()
# Model endpoint: scores one row built from the G1..G10 request parameters.
@app.route("/", methods=["GET", "POST"])
def predict():
    """Return the model's predicted probability for the supplied features.

    Parameters are taken from the JSON body when the request declares a JSON
    content type, and from the query string otherwise. When "G1" is present,
    the values of "G1" through "G10" are assembled into a single-row frame
    and scored with ``model.predict_proba``; the value in the second column
    of the first row is returned as a string under "response".

    Returns:
        A JSON response containing "success" and, when a prediction was
        made, "response".

    Raises:
        werkzeug.exceptions.HTTPException: 400, raised through
            ``flask.abort`` after logging, when anything fails while the
            request is being serviced.
    """
    # Bound before the try block so the error handler can always log it,
    # even when the failure happens before the parameters are read.
    params = None
    try:
        # Update the number of requests seen in the current interval.
        with counter.get_lock():
            counter.value += 1

        data = {"success": False}

        # Only parse the body when it is declared as JSON: accessing the JSON
        # payload of a non-JSON request raises in newer Flask/Werkzeug
        # releases, which would break plain query-string requests.
        if flask.request.is_json:
            params = flask.request.get_json()
        if params is None:
            params = flask.request.args

        # .keys() is kept on purpose: a JSON payload that is not a mapping
        # (for example a list) fails here and is rejected with a 400.
        if "G1" in params.keys():
            feature_names = ["G" + str(i) for i in range(1, 11)]
            new_row = {name: params.get(name) for name in feature_names}
            new_x = pd.DataFrame.from_dict(new_row, orient="index").transpose()
            data["response"] = str(model.predict_proba(new_x)[0][1])
            data["success"] = True

        return flask.jsonify(data)
    except Exception:
        # Log any invalid request, then answer with 400 Bad Request.
        logger.log_text("(" + host + ") Error servicing request: " + str(flask.request) + " " + str(params))
        flask.abort(400)
# Runs the Flask server only when this file is executed directly; no port is
# specified here (the original note says gunicorn manages the ports).
if __name__ == "__main__":
    app.run(host="0.0.0.0")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment