Skip to content

Instantly share code, notes, and snippets.

@bgweber
Created February 17, 2020 18:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bgweber/cef19bf1931139013af24bb32d67ae26 to your computer and use it in GitHub Desktop.
Save bgweber/cef19bf1931139013af24bb32d67ae26 to your computer and use it in GitHub Desktop.
from google.cloud import monitoring_v3
from google.oauth2 import service_account
from google.cloud import logging
import socket
import random
import time
import pandas as pd
from sklearn.linear_model import LogisticRegression
import flask
from multiprocessing import Value
import threading
# GCP project that receives both the custom metrics and the log entries.
# Defined once so the monitoring and logging clients always target the same
# project.
PROJECT_ID = 'serving-268422'

# Identifier for this process: the IP address that the local host name
# resolves to, plus a random suffix. The suffix is presumably there so that
# several processes on the same pod (e.g. multiple server workers) report
# separate time series; it is random, not guaranteed unique.
host = socket.gethostbyname(socket.gethostname()) + " - " + str(random.randint(0, 1000000))

# Load service-account credentials from the local key file and create the
# Stackdriver monitoring client used to write custom metrics.
credentials = service_account.Credentials.from_service_account_file('serving.json')
client = monitoring_v3.MetricServiceClient(credentials=credentials)
project_name = client.project_path(PROJECT_ID)

# Create the Stackdriver logging client and the 'model_service' logger, then
# record that this process has started.
logging_client = logging.Client(project=PROJECT_ID, credentials=credentials)
logger = logging_client.logger('model_service')
logger.log_text("(" + host + ") Launching model service")
# Train the model at import time. The training data is downloaded over HTTP;
# every column except 'label' is a feature, and 'label' is the target.
# LogisticRegression.fit returns the fitted estimator, so it can be chained.
df = pd.read_csv(
    "https://github.com/bgweber/Twitch/raw/master/Recommendations/games-expand.csv"
)
model = LogisticRegression().fit(
    df.drop(columns=['label']),
    df['label'],
)
# Flask application that exposes the model endpoint.
app = flask.Flask(__name__)
# Lock-protected C int (multiprocessing.Value, typecode 'i') that counts the
# requests served since the last metric write. predict() increments it;
# log_requests() reads and resets it.
counter = Value('i', 0)
def write_metric_value(value):
    """Write one point of the custom request metric to Stackdriver.

    The point goes to the ``custom.googleapis.com/serving/requests`` metric
    type. It carries an ``ip`` label set to the module-level ``host``
    identifier (IP address plus a random suffix) and is timestamped with the
    current wall-clock time.

    Args:
        value: Number to record; stored as the point's ``double_value``.

    Raises:
        Exceptions from ``client.create_time_series`` are not caught here and
        propagate to the caller.
    """
    series = monitoring_v3.types.TimeSeries()
    series.metric.type = 'custom.googleapis.com/serving/requests'
    series.metric.labels['ip'] = host
    point = series.points.add()
    point.value.double_value = value
    # Keep full timestamp precision: whole seconds plus the fractional part
    # expressed as nanoseconds (always in [0, 1e9) because now - seconds < 1).
    now = time.time()
    seconds = int(now)
    point.interval.end_time.seconds = seconds
    point.interval.end_time.nanos = int((now - seconds) * 10 ** 9)
    client.create_time_series(project_name, [series])
def log_requests():
    """Report the number of requests counted since the last call, then reset it.

    Each call schedules the next call 60 seconds later on a
    ``threading.Timer``, so one initial call starts a once-per-minute
    reporting loop. The count is printed and written with
    ``write_metric_value``.
    """
    timer = threading.Timer(60.0, log_requests)
    # The timer reschedules itself forever. As a non-daemon thread it would
    # keep the interpreter alive at shutdown (Python waits for all non-daemon
    # threads), so the process could never exit.
    timer.daemon = True
    # Schedule the next run before doing any work, so an exception raised by
    # the write below does not stop future reports.
    timer.start()
    # Read and reset under the counter's lock so that increments made by
    # predict() cannot slip in between the read and the reset.
    with counter.get_lock():
        count = counter.value
        counter.value = 0
    print("writing value: " + str(count))
    write_metric_value(count)


# Start the once-per-minute reporting loop.
log_requests()
# Request fields fed to the model, in column order: G1 .. G10.
_FEATURE_NAMES = ["G" + str(i) for i in range(1, 11)]


@app.route("/", methods=["GET", "POST"])
def predict():
    """Serve a model prediction for one row of G1..G10 features.

    Features come from the JSON body when the request is declared as JSON.
    Otherwise they come from the query string.

    Responses:
        ``{"success": true, "response": "<p>"}``: returned when ``G1`` is
            present, where ``<p>`` is the string form of the model's class-1
            probability.
        ``{"success": false}``: returned when ``G1`` is absent.

    Every call increments the shared request ``counter``. Any error while
    parsing the request or predicting is logged to Stackdriver and answered
    with HTTP 400.
    """
    # Bound before the try so the error handler can always log it, even when
    # reading the request body itself fails.
    params = None
    try:
        with counter.get_lock():
            counter.value += 1
        data = {"success": False}
        # Only parse the body when the client declares JSON. On newer Werkzeug,
        # accessing request.json on a non-JSON (e.g. plain GET) request raises
        # instead of returning None. A malformed JSON body still raises here
        # and is turned into a 400 below.
        params = flask.request.get_json() if flask.request.is_json else None
        if params is None:
            params = flask.request.args
        # .keys() is kept deliberately: a non-object JSON body (list, number)
        # fails here and becomes a 400, rather than silently returning
        # {"success": false}.
        if "G1" in params.keys():
            new_row = {name: params.get(name) for name in _FEATURE_NAMES}
            new_x = pd.DataFrame.from_dict(new_row, orient="index").transpose()
            data["response"] = str(model.predict_proba(new_x)[0][1])
            data["success"] = True
        return flask.jsonify(data)
    except Exception:
        # Log the invalid request, then reject it.
        logger.log_text("(" + host + ") Error servicing request: " + str(flask.request) + " " + str(params))
        flask.abort(400)
# Local development entry point: running this file directly starts Flask's
# built-in server, listening on all interfaces. When the app is served by
# gunicorn, the module is imported rather than run as __main__. This branch is
# then skipped, and gunicorn decides the bind address and port.
if __name__ == "__main__":
    app.run(host="0.0.0.0")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment