Created November 12, 2020 17:12
Google Cloud Run DistilBERT example
app.py
import os
import json

from flask import Flask, request, Response
import torch
from transformers import (
    DistilBertTokenizer, DistilBertForQuestionAnswering,
)


class Model(object):
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    def encode(self, question, context):
        encoded = self.tokenizer.encode_plus(question, context)
        return encoded["input_ids"], encoded["attention_mask"]

    def decode(self, token):
        answer_tokens = self.tokenizer.convert_ids_to_tokens(token, skip_special_tokens=True)
        return self.tokenizer.convert_tokens_to_string(answer_tokens)

    def predict(self, input):
        question, context = input['question'], input['context']
        input_ids, attention_mask = self.encode(question, context)
        # The answer span runs from the highest-scoring start token to the
        # highest-scoring end token. Note: transformers < 4.0 returns a plain
        # (start_scores, end_scores) tuple here; newer versions need
        # return_dict=False for this unpacking to work.
        start_scores, end_scores = self.model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))
        ans_tokens = input_ids[torch.argmax(start_scores): torch.argmax(end_scores) + 1]
        answer = self.decode(ans_tokens)
        return answer


# DistilBERT tokenizer and weights are loaded from the local ./model
# directory, which bootstrap.py populates at image build time.
model = Model(
    tokenizer=DistilBertTokenizer.from_pretrained('./model', return_token_type_ids=True),
    model=DistilBertForQuestionAnswering.from_pretrained('./model'),
)

app = Flask(__name__)


@app.route('/')
def hello_world():
    target = os.environ.get('TARGET', 'World')
    return 'Hello {}!\n'.format(target)


@app.route('/predict', methods=['POST'])
def predict():
    payload = request.get_json()
    return Response(json.dumps(model.predict(payload)), mimetype='application/json')


if __name__ == "__main__":
    app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
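
As a quick local check, the /predict route can be exercised with Flask's built-in test client, so no server needs to be started. The snippet below is only a sketch: the file name test_app.py is hypothetical, it assumes bootstrap.py has already populated ./model (importing app loads the model), and the question/context payload is an arbitrary example.

# test_app.py -- hypothetical local smoke test, not part of the gist.
# Importing app loads the DistilBERT weights from ./model, so run bootstrap.py first.
import json

from app import app


def test_predict():
    client = app.test_client()
    payload = {
        "question": "What does the service return?",
        "context": "The service returns the answer span extracted by DistilBERT.",
    }
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 200
    print(json.loads(resp.data))


if __name__ == "__main__":
    test_predict()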
bootstrap.py
import torch
from transformers import (
    DistilBertTokenizer, DistilBertForQuestionAnswering,
)

# Download the tokenizer and the SQuAD-distilled question-answering weights,
# then save both to ./model so the service can load them without network access.
DistilBertTokenizer.from_pretrained('distilbert-base-uncased', return_token_type_ids=True).save_pretrained('./model')
DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad').save_pretrained('./model')
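
Before building the image it can be worth confirming that the weights really landed in ./model and load cleanly from disk. This optional check is not part of the gist; it only assumes bootstrap.py has already been run in the same directory.

# verify_model.py -- optional sanity check (hypothetical, not part of the gist).
import os

from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering

print("files in ./model:", sorted(os.listdir("./model")))
tokenizer = DistilBertTokenizer.from_pretrained("./model")
model = DistilBertForQuestionAnswering.from_pretrained("./model")
print("vocab size:", tokenizer.vocab_size)
print("parameters:", sum(p.numel() for p in model.parameters()))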
Dockerfile
# Use the official lightweight Python image.
# https://hub.docker.com/_/python
FROM python:3.7-slim

# Allow statements and log messages to immediately appear in the Knative logs
ENV PYTHONUNBUFFERED True

# Copy local code to the container image.
ENV APP_HOME /app
WORKDIR $APP_HOME

# Install production dependencies.
RUN pip install flask gunicorn transformers
RUN pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

# The bootstrap script will download the model file into the container image
COPY ./bootstrap.py .
RUN python bootstrap.py

# Now copy the app
COPY ./app.py .

# Run the web service on container startup. Here we use gunicorn
CMD exec gunicorn --bind :$PORT --workers 1 --threads 1 --timeout 0 app:app
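
Once the image is built and the container is running, whether locally or as a Cloud Run service, the endpoint can be called over plain HTTP. The snippet below is a rough client-side sketch using the requests library; SERVICE_URL is a placeholder (a local container on port 8080 here, or the URL Cloud Run assigns after deployment), and the question/context values are arbitrary examples.

# client.py -- hypothetical client script, not part of the gist.
# SERVICE_URL is a placeholder: a locally running container is assumed here;
# for Cloud Run, substitute the https URL assigned to the deployed service.
import requests

SERVICE_URL = "http://localhost:8080"

payload = {
    "question": "Where does the service run?",
    "context": "The DistilBERT question answering service runs on Google Cloud Run.",
}

resp = requests.post(f"{SERVICE_URL}/predict", json=payload)
resp.raise_for_status()
print(resp.json())  # the extracted answer span, returned as a JSON string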