Google Cloud Run DistilBERT example — a question-answering service in three files: app.py (the Flask app), bootstrap.py (downloads the model at image build time), and the Dockerfile.
# app.py — Flask service that answers questions with DistilBERT on Cloud Run.
import os
import json

from flask import Flask, request, Response
import torch
from transformers import (
    DistilBertTokenizer,
    DistilBertForQuestionAnswering,
)


class Model(object):
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()  # inference only

    def encode(self, question, context):
        # Tokenize the question/context pair into model inputs.
        encoded = self.tokenizer.encode_plus(question, context)
        return encoded["input_ids"], encoded["attention_mask"]

    def decode(self, tokens):
        # Turn the predicted token ids back into a readable answer string.
        answer_tokens = self.tokenizer.convert_ids_to_tokens(tokens, skip_special_tokens=True)
        return self.tokenizer.convert_tokens_to_string(answer_tokens)

    def predict(self, payload):
        question, context = payload['question'], payload['context']
        input_ids, attention_mask = self.encode(question, context)
        # outputs[0]/outputs[1] are the start/end logits in both the tuple and
        # ModelOutput return formats of transformers.
        with torch.no_grad():
            outputs = self.model(
                torch.tensor([input_ids]),
                attention_mask=torch.tensor([attention_mask]),
            )
        start_scores, end_scores = outputs[0], outputs[1]
        # The answer is the token span between the highest-scoring start and end positions.
        ans_tokens = input_ids[torch.argmax(start_scores):torch.argmax(end_scores) + 1]
        return self.decode(ans_tokens)


# DistilBERT weights and tokenizer are baked into the image under ./model by bootstrap.py.
model = Model(
    tokenizer=DistilBertTokenizer.from_pretrained('./model', return_token_type_ids=True),
    model=DistilBertForQuestionAnswering.from_pretrained('./model'),
)

app = Flask(__name__)


@app.route('/')
def hello_world():
    target = os.environ.get('TARGET', 'World')
    return 'Hello {}!\n'.format(target)


@app.route('/predict', methods=['POST'])
def predict():
    payload = request.get_json()
    return Response(json.dumps(model.predict(payload)), mimetype='application/json')


if __name__ == "__main__":
    app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
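A quick way to exercise the /predict route once app.py is running locally (a sketch: it assumes the server is listening on port 8080, and the question and context are only sample values):

import requests

resp = requests.post(
    'http://localhost:8080/predict',
    json={
        'question': 'What serves the model?',
        'context': 'The DistilBERT model is served by a Flask app on Google Cloud Run.',
    },
)
print(resp.json())  # the predicted answer string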
# bootstrap.py — run at image build time: downloads the tokenizer and the
# SQuAD-distilled weights and saves them under ./model for offline use.
from transformers import (
    DistilBertTokenizer, DistilBertForQuestionAnswering,
)

DistilBertTokenizer.from_pretrained('distilbert-base-uncased', return_token_type_ids=True).save_pretrained('./model')
DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad').save_pretrained('./model')
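Before building the image, the saved files can be sanity-checked locally (a sketch assuming bootstrap.py has already written ./model; the question and context are invented), mirroring the inference logic in app.py:

import torch
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering

tokenizer = DistilBertTokenizer.from_pretrained('./model', return_token_type_ids=True)
model = DistilBertForQuestionAnswering.from_pretrained('./model')

encoded = tokenizer.encode_plus(
    'Where does the service run?',
    'The model is served from a container on Google Cloud Run.',
)
with torch.no_grad():
    outputs = model(
        torch.tensor([encoded['input_ids']]),
        attention_mask=torch.tensor([encoded['attention_mask']]),
    )
start, end = torch.argmax(outputs[0]), torch.argmax(outputs[1])
tokens = tokenizer.convert_ids_to_tokens(
    encoded['input_ids'][start:end + 1], skip_special_tokens=True
)
print(tokenizer.convert_tokens_to_string(tokens))  # prints the predicted answer span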
# Dockerfile — builds the container image for the Cloud Run service.
# Use the official lightweight Python image.
# https://hub.docker.com/_/python
FROM python:3.7-slim
# Allow statements and log messages to immediately appear in the Knative logs
ENV PYTHONUNBUFFERED True
# Copy local code to the container image.
ENV APP_HOME /app
WORKDIR $APP_HOME
# Install production dependencies.
RUN pip install flask gunicorn transformers
RUN pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# The bootstrap script will download the model file into the container image
COPY ./bootstrap.py .
RUN python bootstrap.py
# Now copy the app
COPY ./app.py .
# Run the web service on container startup. Here we use gunicorn
CMD exec gunicorn --bind :$PORT --workers 1 --threads 1 --timeout 0 app:app
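With the three files in one directory, the image can be built and deployed roughly like this (the project ID, service name, region, and memory size are placeholders to adjust; Cloud Run needs enough memory to hold the model):

gcloud builds submit --tag gcr.io/PROJECT_ID/distilbert-qa
gcloud run deploy distilbert-qa \
  --image gcr.io/PROJECT_ID/distilbert-qa \
  --platform managed \
  --region us-central1 \
  --memory 2Gi \
  --allow-unauthenticated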