Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
Google Cloud Run Distilbert example
import os
import json
from flask import Flask, request, Response
import torch
from transformers import (
DistilBertTokenizer, DistilBertForQuestionAnswering,
class Model(object):
    """Extractive question-answering wrapper around a tokenizer/model pair.

    The tokenizer must provide ``encode_plus``, ``convert_ids_to_tokens`` and
    ``convert_tokens_to_string``. The model is called with a one-element batch
    of input ids plus an attention mask and must return the start and end
    scores as its first two outputs (a plain tuple in older ``transformers``
    releases, a ``ModelOutput`` in 4.x and later).
    """

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    def encode(self, question, context):
        """Tokenize a question/context pair.

        Returns:
            ``(input_ids, attention_mask)`` as produced by the tokenizer's
            ``encode_plus``; ``predict`` wraps each in a one-element batch.
        """
        encoded = self.tokenizer.encode_plus(question, context)
        return encoded["input_ids"], encoded["attention_mask"]

    def decode(self, token):
        """Convert a sequence of token ids back to text, skipping special tokens."""
        answer_tokens = self.tokenizer.convert_ids_to_tokens(token, skip_special_tokens=True)
        return self.tokenizer.convert_tokens_to_string(answer_tokens)

    def predict(self, input):
        """Answer ``input['question']`` using ``input['context']``.

        Args:
            input: Mapping with ``'question'`` and ``'context'`` entries.

        Returns:
            The predicted answer span decoded to a string. It is empty when the
            highest-scoring end position precedes the highest-scoring start.

        Raises:
            KeyError: if ``'question'`` or ``'context'`` is missing.
        """
        question, context = input['question'], input['context']
        input_ids, attention_mask = self.encode(question, context)
        # Inference only: skip autograd bookkeeping to save time and memory.
        with torch.no_grad():
            outputs = self.model(
                torch.tensor([input_ids]),
                attention_mask=torch.tensor([attention_mask]),
            )
        # Index instead of tuple-unpacking: transformers>=4 returns a
        # ModelOutput (a dict subclass) whose iteration yields key names, not
        # tensors. Integer indexing works for both that and a plain tuple.
        start_scores, end_scores = outputs[0], outputs[1]
        start = int(torch.argmax(start_scores))
        end = int(torch.argmax(end_scores))
        # An end before the start gives an empty slice, i.e. an empty answer.
        ans_tokens = input_ids[start:end + 1]
        return self.decode(ans_tokens)
# DistilBERT
# Module-level singletons, built once when this module is imported; the
# /predict handler below uses this shared `model` for every request.
# NOTE(review): this call appears truncated in this copy. Model.__init__ also
# requires a `model=` argument (presumably a DistilBertForQuestionAnswering
# loaded from './model'), and the closing parenthesis is missing.
# TODO confirm against the original file.
model = Model(
tokenizer=DistilBertTokenizer.from_pretrained('./model', return_token_type_ids=True),
# WSGI application object; the Dockerfile CMD serves it as `app:app`.
app = Flask(__name__)
# Registered at '/'; without a route decorator the handler is never reachable.
@app.route('/')
def hello_world():
    """Return a greeting for the root URL.

    The name greeted comes from the ``TARGET`` environment variable and
    defaults to ``'World'``. This is handy as a quick check that the service
    is up.
    """
    target = os.environ.get('TARGET', 'World')
    return 'Hello {}!\n'.format(target)
@app.route('/predict', methods=['POST'])
def predict():
    """Answer a question about a context passage.

    Expects a JSON object body ``{"question": str, "context": str}`` and
    responds with the JSON-encoded answer string. A missing, unparsable or
    incomplete body gets a 400 with a JSON error object instead of an
    unhandled 500.
    """
    # silent=True makes get_json return None on a missing/invalid JSON body
    # instead of raising, so we can reply with a JSON error below.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or not all(
        isinstance(payload.get(key), str) for key in ('question', 'context')
    ):
        return Response(
            json.dumps({'error': 'expected a JSON object with string "question" and "context" fields'}),
            status=400,
            mimetype='application/json',
        )
    return Response(json.dumps(model.predict(payload)), mimetype='application/json')
if __name__ == "__main__":
    # Local development entry point only; in the container gunicorn serves
    # `app:app` (see the Dockerfile CMD). Bind all interfaces and honour the
    # PORT environment variable, defaulting to 8080. Debug mode is left off:
    # the Werkzeug debugger allows arbitrary code execution.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
import torch
from transformers import (
DistilBertTokenizer, DistilBertForQuestionAnswering,
# Bootstrap: fetch the tokenizer and question-answering weights once and save
# them to ./model, which the Flask app loads with from_pretrained('./model').
DistilBertTokenizer.from_pretrained('distilbert-base-uncased', return_token_type_ids=True).save_pretrained('./model')
# The QA model weights must be saved too, and from a SQuAD-fine-tuned
# checkpoint: loading DistilBertForQuestionAnswering from the base
# 'distilbert-base-uncased' weights leaves the span-prediction head untrained.
# NOTE(review): checkpoint name assumed; confirm it matches the intended model.
DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad').save_pretrained('./model')
# Container image for the DistilBERT question-answering service.
# Use the official lightweight Python image.
FROM python:3.7-slim
# Allow statements and log messages to immediately appear in the Knative logs
# NOTE(review): the instruction this comment describes appears to be missing;
# the usual one is `ENV PYTHONUNBUFFERED True`. TODO confirm.
# Copy local code to the container image.
# NOTE(review): no WORKDIR is set, so the COPY steps below land in the base
# image's default working directory. TODO confirm whether a WORKDIR line was lost.
# Install production dependencies.
RUN pip install flask gunicorn transformers
# NOTE(review): `-f` has no argument, so pip rejects this line as written.
# The "+cpu" wheels are published on the PyTorch find-links index
# (https://download.pytorch.org/whl/torch_stable.html). TODO restore the URL.
RUN pip install torch==1.6.0+cpu -f
# The bootstrap script will download the model file into the container image
COPY ./ .
# NOTE(review): no script path is given, so this starts a bare interpreter
# rather than the bootstrap download, and ./model is not populated at build
# time. TODO restore the script name.
RUN python
# Now copy the app
# NOTE(review): this repeats the earlier `COPY ./ .`. Presumably the first copy
# was meant to include only the bootstrap script, so that code edits don't
# invalidate the cached model-download layer. TODO confirm.
COPY ./ .
# Run the web service on container startup. Here we use gunicorn
# `app:app` names the `app` object in module `app` (presumably the Flask file).
# One worker with one thread serves one request at a time; `--timeout 0`
# disables gunicorn's worker timeout. $PORT must be set by the runtime.
CMD exec gunicorn --bind :$PORT --workers 1 --threads 1 --timeout 0 app:app
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment