Skip to content

Instantly share code, notes, and snippets.

Created November 12, 2020 17:12
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
What would you like to do?
Google Cloud Run DistilBERT example
import json
import os

import torch
from flask import Flask, request, Response
from transformers import (
    DistilBertTokenizer, DistilBertForQuestionAnswering,
)
class Model(object):
    """Extractive question-answering wrapper around a tokenizer/model pair.

    The tokenizer turns a (question, context) pair into token ids, the model
    scores every token as a possible answer start/end, and the best-scoring
    span is decoded back into text.
    """

    def __init__(self, tokenizer, model):
        """Store the collaborators used for inference.

        Args:
            tokenizer: Object providing ``encode_plus``,
                ``convert_ids_to_tokens`` and ``convert_tokens_to_string``.
            model: Callable invoked as ``model(input_ids, attention_mask=...)``
                whose first two outputs are the start and end scores.
        """
        self.tokenizer = tokenizer
        self.model = model

    def encode(self, question, context):
        """Tokenize a question/context pair.

        Returns:
            A ``(input_ids, attention_mask)`` tuple taken from the
            tokenizer's ``encode_plus`` result.
        """
        encoded = self.tokenizer.encode_plus(question, context)
        return encoded["input_ids"], encoded["attention_mask"]

    def decode(self, token):
        """Convert a sequence of token ids back into a string.

        Special tokens are skipped during the id-to-token conversion.
        """
        answer_tokens = self.tokenizer.convert_ids_to_tokens(token, skip_special_tokens=True)
        return self.tokenizer.convert_tokens_to_string(answer_tokens)

    def predict(self, input):
        """Answer a question from the supplied context.

        Args:
            input: Mapping with ``'question'`` and ``'context'`` entries.
                (The name shadows the ``input`` builtin; it is kept so that
                keyword callers keep working.)

        Returns:
            The decoded answer span. The string is empty when the
            best-scoring end position precedes the best-scoring start.

        Raises:
            KeyError: If ``'question'`` or ``'context'`` is missing.
        """
        question, context = input['question'], input['context']
        input_ids, attention_mask = self.encode(question, context)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(
                torch.tensor([input_ids]),
                attention_mask=torch.tensor([attention_mask]),
            )

        # Index positionally instead of tuple-unpacking. Newer Transformers
        # releases return a dict-like output object whose iteration yields
        # key names, so ``a, b = model(...)`` would bind strings there.
        # Positional indexing works for both plain tuples and those objects.
        start_scores, end_scores = outputs[0], outputs[1]

        start = int(torch.argmax(start_scores))
        end = int(torch.argmax(end_scores))
        return self.decode(input_ids[start:end + 1])
# DistilBERT
# Both the tokenizer and the question-answering weights are loaded from the
# local ./model directory.
# NOTE(review): the `model=` argument and the closing parenthesis were missing
# from this copy of the file. They are reconstructed from Model.__init__'s
# signature and the imported DistilBertForQuestionAnswering class; confirm
# that the weights are indeed stored under ./model.
model = Model(
    tokenizer=DistilBertTokenizer.from_pretrained('./model', return_token_type_ids=True),
    model=DistilBertForQuestionAnswering.from_pretrained('./model'),
)

# WSGI application object; referenced as `app:app` by the gunicorn command.
app = Flask(__name__)
# NOTE(review): no route decorator is visible for this handler in this copy of
# the file, so it is not reachable over HTTP as written. It presumably was
# registered on the root path; confirm and restore the decorator.
def hello_world():
    """Return a greeting for the name in the ``TARGET`` environment variable.

    Falls back to ``'World'`` when ``TARGET`` is not set.
    """
    target = os.environ.get('TARGET', 'World')
    return f'Hello {target}!\n'
@app.route('/predict', methods=['POST'])
def predict():
    """Answer a question posted as JSON.

    Expects a JSON object with ``question`` and ``context`` entries and
    responds with the predicted answer encoded as a JSON string.

    Returns:
        A 200 JSON response holding the answer, or a 400 JSON response with
        an ``error`` message when the body is not a JSON object or lacks one
        of the required entries.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so every malformed request takes the same 400 path below.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'question' not in payload or 'context' not in payload:
        return Response(
            json.dumps({'error': "Request body must be a JSON object with 'question' and 'context'."}),
            status=400,
            mimetype='application/json',
        )
    return Response(json.dumps(model.predict(payload)), mimetype='application/json')
if __name__ == "__main__":
    # Entry point for running the file directly; the container image starts
    # the app through gunicorn instead.
    # NOTE(review): this statement was garbled in this copy of the file (the
    # call name and the host value were lost). It is rebuilt as app.run(...)
    # bound to all interfaces, because Flask treats an empty host as loopback.
    # Debug mode is deliberately left off; confirm against the original.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
import torch
from transformers import (
    DistilBertTokenizer, DistilBertForQuestionAnswering,
)

# Download the pretrained tokenizer and store it under ./model so it can be
# loaded from local disk later.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', return_token_type_ids=True)
tokenizer.save_pretrained('./model')

# NOTE(review): DistilBertForQuestionAnswering is imported but no statement
# that downloads and saves the model weights is visible in this copy of the
# script; a matching from_pretrained(...).save_pretrained('./model') call for
# the question-answering checkpoint appears to be missing. Confirm which
# checkpoint was intended and restore it.
# Use the official lightweight Python image.
FROM python:3.7-slim

# Allow statements and log messages to immediately appear in the Knative logs
ENV PYTHONUNBUFFERED=True

# Copy local code to the container image.
ENV APP_HOME=/app
WORKDIR $APP_HOME

# Install production dependencies.
RUN pip install flask gunicorn transformers
# The +cpu wheels are not on PyPI; -f needs the PyTorch wheel index to find them.
RUN pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

# The bootstrap script will download the model file into the container image
# NOTE(review): the script file name was missing in this copy; "bootstrap.py"
# is inferred from the comment above. Confirm the actual name.
COPY ./bootstrap.py .
RUN python bootstrap.py

# Now copy the app
# NOTE(review): "app.py" is inferred from the `app:app` gunicorn target below.
COPY ./app.py .

# Run the web service on container startup. Here we use gunicorn
# with a single worker and thread; --timeout 0 disables the worker timeout.
CMD exec gunicorn --bind :$PORT --workers 1 --threads 1 --timeout 0 app:app
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment