# handler.py
from encoder import get_encoder

encoder = get_encoder()


def pre_inference(sample, signature, metadata):
    # Encode the request text into token ids and wrap them in a batch of one
    context = encoder.encode(sample["text"])
    return {"context": [context]}


def post_inference(prediction, signature, metadata):
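As a rough sketch of how a pre/post handler pair like this might fit together, the example below swaps in a hypothetical `ToyEncoder` for the real `encoder` module and assumes the model returns its generated token ids under a `"sample"` key. Neither detail is confirmed by the source, so treat the `post_inference` body as an illustration only.

```python
class ToyEncoder:
    """Hypothetical stand-in for the object returned by get_encoder()."""

    def encode(self, text):
        # Map each character to its code point (a real encoder would use BPE).
        return [ord(ch) for ch in text]

    def decode(self, tokens):
        return "".join(chr(t) for t in tokens)


encoder = ToyEncoder()


def pre_inference(sample, signature, metadata):
    context = encoder.encode(sample["text"])
    return {"context": [context]}


def post_inference(prediction, signature, metadata):
    # Assumption: the model output exposes generated token ids under "sample".
    return encoder.decode(prediction["sample"])


model_input = pre_inference({"text": "hello"}, None, None)
# Pretend the model simply echoes its input tokens back.
fake_prediction = {"sample": model_input["context"][0]}
result = post_inference(fake_prediction, None, None)
```

The round trip only shows the shape of the data passing through the two hooks; the real handler would decode whatever tensor the deployed model actually returns.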
# cortex.yaml
- kind: deployment
  name: text

- kind: api
  name: generator
  tensorflow:
    model: s3://cortex-test-project/124M/124M/
    request_handler: handler.py
# encoder.py
import tensorflow as tf
import os
import json
import regex as re
from functools import lru_cache
import requests
import boto3


@lru_cache()
# predictor.py
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Download the pretrained DistilGPT2 model and put it in evaluation mode
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
model.eval()
# predictor.py
def predict(sample, metadata):
    # Tokenize the prompt, sample a continuation, and decode it back to text
    indexed_tokens = tokenizer.encode(sample["text"])
    output = sample_sequence(model, metadata['num_words'], indexed_tokens, device=metadata['device'])
    return tokenizer.decode(
        output[0, 0:].tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
    )
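The `sample_sequence` helper called above is not shown. To give a feel for what a sampling loop of this kind does, here is a minimal pure-Python sketch of top-k sampling. It is a stand-in under stated assumptions, not the function used above: the signature differs (no `model` or `device`), and `next_token_scores` and `toy_scores` are hypothetical names invented for illustration.

```python
import random


def sample_sequence(next_token_scores, length, context, top_k=2, seed=0):
    """Extend `context` by `length` tokens using top-k sampling.

    `next_token_scores` is a hypothetical callable mapping the tokens so far
    to a dict of {token_id: positive score}; a real model would supply logits.
    """
    rng = random.Random(seed)
    generated = list(context)
    for _ in range(length):
        scores = next_token_scores(generated)
        # Keep only the k highest-scoring candidates.
        top = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
        tokens = [tok for tok, _ in top]
        weights = [score for _, score in top]
        # Draw the next token in proportion to its score.
        generated.append(rng.choices(tokens, weights=weights, k=1)[0])
    return generated


def toy_scores(tokens):
    # Toy "model": favour the token after the last one, in a vocabulary of 5.
    last = tokens[-1]
    return {t: (3.0 if t == (last + 1) % 5 else 1.0) for t in range(5)}


output = sample_sequence(toy_scores, length=4, context=[0, 1])
```

The returned list starts with the original context followed by the sampled tokens, which mirrors why `predict` decodes the whole output row rather than only the new tokens.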
# cortex.yaml
- kind: deployment
  name: text

- kind: api
  name: generator
  predictor:
    path: predictor.py
    metadata:
# predictor.py
from summarizer import Summarizer


class PythonPredictor:
    def __init__(self, config):
        # Load the summarization model once, at construction
        self.model = Summarizer()

    def predict(self, payload):
        # Summarize the "text" field of the request payload
        return self.model(payload["text"])
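To exercise a predictor class of this shape without downloading a model, one option is to swap in a stub. The sketch below assumes a hypothetical `StubSummarizer` and adds a `model_factory` argument purely for illustration; neither is part of the original class or of the `summarizer` package.

```python
class StubSummarizer:
    """Hypothetical stand-in for summarizer.Summarizer: keeps the first sentence."""

    def __call__(self, text):
        first, _, _ = text.partition(". ")
        return first.rstrip(".") + "."


class PythonPredictor:
    def __init__(self, config, model_factory=StubSummarizer):
        # model_factory is an illustrative hook, not part of the original class.
        self.model = model_factory()

    def predict(self, payload):
        return self.model(payload["text"])


predictor = PythonPredictor(config={})
summary = predictor.predict({"text": "The first sentence stays. The second is dropped."})
```

The payload is a plain dictionary with a `"text"` key, matching what `predict` reads above.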
# requirements.txt
transformers
spacy==2.1.3
bert-extractive-summarizer
# cortex.yaml
- kind: deployment
  name: text

- kind: api
  name: summarizer
  predictor:
    type: python
    path: predictor.py
  compute:
    mem: 4G
import pickle

from dvc import api

ctx = {}


def init(model_path, metadata):
    # Read the pickled model and pipeline from the DVC repo and deserialize them
    ctx["model"] = pickle.loads(api.read(metadata["model_path"], metadata["dvc_repo"], mode="rb"))
    ctx["pipeline"] = pickle.loads(
        api.read(metadata["pipeline_path"], metadata["dvc_repo"], mode="rb")
    )
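The loading pattern in `init` can be tried locally without a DVC repository. The sketch below replaces `api.read` with a hypothetical in-memory `fake_read` and adds a `read` parameter so the stub can be injected; the store contents, file names, and that extra parameter are all assumptions made for illustration.

```python
import pickle

# Hypothetical in-memory "repo": path -> pickled bytes, standing in for DVC storage.
FAKE_STORE = {
    "model.pkl": pickle.dumps({"weights": [0.5, 1.5]}),
    "pipeline.pkl": pickle.dumps(["lowercase", "tokenize"]),
}


def fake_read(path, repo, mode="rb"):
    # Mimics the call shape used above: (path, repo, mode=...) -> bytes.
    return FAKE_STORE[path]


ctx = {}


def init(model_path, metadata, read=fake_read):
    # `read` is an illustrative injection point, not part of the original signature.
    ctx["model"] = pickle.loads(read(metadata["model_path"], metadata["dvc_repo"], mode="rb"))
    ctx["pipeline"] = pickle.loads(read(metadata["pipeline_path"], metadata["dvc_repo"], mode="rb"))


init(None, {"model_path": "model.pkl", "pipeline_path": "pipeline.pkl", "dvc_repo": "unused"})
```

After `init` runs, later request handlers can look up `ctx["model"]` and `ctx["pipeline"]` instead of reloading them on every call.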