Skip to content

Instantly share code, notes, and snippets.

@CaptEmulation
Created November 12, 2020 17:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save CaptEmulation/6ce59a31cebc3014b42e2a3faaafe1d9 to your computer and use it in GitHub Desktop.
Save CaptEmulation/6ce59a31cebc3014b42e2a3faaafe1d9 to your computer and use it in GitHub Desktop.
Google Cloud Run DistilBERT example
import os
import json
from flask import Flask, request, Response
import torch
from transformers import (
DistilBertTokenizer, DistilBertForQuestionAnswering,
)
class Model(object):
    """Extractive question answering on top of a tokenizer/model pair.

    ``tokenizer`` must provide ``encode_plus``, ``convert_ids_to_tokens`` and
    ``convert_tokens_to_string``. ``model`` is called with a batch of input ids
    plus an ``attention_mask`` keyword and must expose the start and end
    scores as its first two outputs (positions 0 and 1).
    """

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    def encode(self, question, context):
        """Tokenize ``question`` and ``context`` as a single sequence pair.

        Returns:
            An ``(input_ids, attention_mask)`` tuple taken from the result of
            the tokenizer's ``encode_plus``.
        """
        encoded = self.tokenizer.encode_plus(question, context)
        return encoded["input_ids"], encoded["attention_mask"]

    def decode(self, token):
        """Convert a sequence of token ids back to text, skipping special tokens.

        Args:
            token: the token ids to decode (a sequence of ids, despite the name).
        """
        answer_tokens = self.tokenizer.convert_ids_to_tokens(token, skip_special_tokens=True)
        return self.tokenizer.convert_tokens_to_string(answer_tokens)

    def predict(self, input):
        """Return the answer span the model extracts from the context.

        Args:
            input: mapping with ``'question'`` and ``'context'`` entries.

        Returns:
            The decoded text of the tokens between the highest-scoring start
            and end positions (inclusive). An empty id slice is decoded when
            the predicted end precedes the start.

        Raises:
            KeyError: if ``'question'`` or ``'context'`` is missing.
        """
        question, context = input['question'], input['context']
        input_ids, attention_mask = self.encode(question, context)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(
                torch.tensor([input_ids]),
                attention_mask=torch.tensor([attention_mask]),
            )
        # Index positionally instead of tuple-unpacking: transformers 4.x
        # returns a ModelOutput (an OrderedDict), and unpacking that yields its
        # key strings rather than the score tensors. Indexing works for both
        # plain tuples and ModelOutput objects.
        start_scores, end_scores = outputs[0], outputs[1]
        start = int(torch.argmax(start_scores))
        end = int(torch.argmax(end_scores))
        return self.decode(input_ids[start:end + 1])
# DistilBERT question-answering wrapper, built once at import time from the
# ./model directory (the bootstrap step of the image build saves the
# tokenizer and weights there).
model = Model(
    DistilBertTokenizer.from_pretrained('./model', return_token_type_ids=True),
    DistilBertForQuestionAnswering.from_pretrained('./model'),
)

# WSGI application object; the container's gunicorn command serves it as app:app.
app = Flask(__name__)
@app.route('/')
def hello_world():
    """Greet the name in the TARGET environment variable, defaulting to 'World'."""
    return 'Hello {}!\n'.format(os.environ.get('TARGET', 'World'))
@app.route('/predict', methods=['POST'])
def predict():
    """Answer a question about a context passage.

    Expects a JSON object body with ``question`` and ``context`` fields and
    responds with the answer text, JSON-encoded. A body that is missing, is
    not JSON, is not an object, or lacks either field gets a 400 response
    with a JSON ``error`` message instead of failing inside ``model.predict``.
    """
    # silent=True makes get_json return None for a missing/unparseable body,
    # so every bad payload goes through the same 400 path below.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'question' not in payload or 'context' not in payload:
        return Response(
            json.dumps({'error': "expected a JSON object with 'question' and 'context' fields"}),
            status=400,
            mimetype='application/json',
        )
    return Response(json.dumps(model.predict(payload)), mimetype='application/json')
if __name__ == "__main__":
    # Local development entry point; the container image starts gunicorn instead.
    # The Werkzeug debugger lets anyone who can reach it execute arbitrary code,
    # so while bound to 0.0.0.0 it is opt-in (FLASK_DEBUG=1) rather than always on.
    app.run(
        debug=os.environ.get('FLASK_DEBUG') == '1',
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 8080)),
    )
import torch
from transformers import (
DistilBertTokenizer, DistilBertForQuestionAnswering,
)
# Fetch the pretrained tokenizer and the 'distilled-squad' QA checkpoint and
# write both into ./model, the directory the web app loads them from.
model_dir = './model'

tokenizer = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased',
    return_token_type_ids=True,
)
tokenizer.save_pretrained(model_dir)

qa_model = DistilBertForQuestionAnswering.from_pretrained(
    'distilbert-base-uncased-distilled-squad',
)
qa_model.save_pretrained(model_dir)
# Use the official lightweight Python image.
# https://hub.docker.com/_/python
FROM python:3.7-slim

# Allow statements and log messages to immediately appear in the Knative logs
ENV PYTHONUNBUFFERED=True

# Application files live in $APP_HOME, which is also the working directory.
ENV APP_HOME=/app
WORKDIR $APP_HOME

# Install production dependencies. --no-cache-dir keeps pip's download/wheel
# cache out of the image layers.
# NOTE(review): transformers is unpinned; consider pinning a version that has
# been tested with app.py so rebuilds stay reproducible.
RUN pip install --no-cache-dir flask gunicorn transformers
# CPU-only torch build (the +cpu wheel from the PyTorch index).
RUN pip install --no-cache-dir torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

# The bootstrap script downloads the tokenizer and model weights into the image
COPY ./bootstrap.py .
RUN python bootstrap.py

# Now copy the app
COPY ./app.py .

# Run the web service on container startup. Here we use gunicorn with a single
# worker and thread; --timeout 0 disables gunicorn's worker timeout.
CMD exec gunicorn --bind :$PORT --workers 1 --threads 1 --timeout 0 app:app
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment