Skip to content

Instantly share code, notes, and snippets.

@nforest
Created July 27, 2022 06:30
Show Gist options
  • Star 11 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save nforest/d1432b917468f5ad24b83954c98e67b1 to your computer and use it in GitHub Desktop.
Save nforest/d1432b917468f5ad24b83954c98e67b1 to your computer and use it in GitHub Desktop.
Use your own model behind the GitHub Copilot completion interface
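# To point GitHub Copilot at this server, add the following to VS Code's
# settings.json, replacing the URL with the host and port this app listens on:
# "github.copilot.advanced": {
# "debug.overrideEngine": "codeai",
# "debug.testOverrideProxyUrl": "http://9.135.120.183:5000",
# "debug.overrideProxyUrl": "http://9.135.120.183:5000"
# }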
import json
import time
import random
import string
import torch
import transformers
from flask import Flask, request
class KeywordsStoppingCriteria(transformers.StoppingCriteria):
    """Stop ``generate`` once every sequence has ended on EOS or emitted a keyword.

    Args:
        tokenizer: tokenizer used to decode generated tokens; its
            ``eos_token_id`` (if any) marks rows that already finished.
        keywords: stop strings. ``None`` and empty strings are ignored,
            since ``""`` is a substring of every text and would stop
            generation immediately.
        prompt_length: number of prompt tokens at the start of each row of
            ``input_ids``. Only tokens after it are searched, so a stop
            string already present at the end of the prompt does not end
            generation before anything is produced. When ``None`` it is
            taken as ``input_ids.shape[-1] - 1`` on the first call, which
            assumes that call happens right after the first new token is
            appended (NOTE(review): true for the HF sample/beam loops —
            confirm for other decoding modes). Instances are therefore
            single-use.
    """

    def __init__(self, tokenizer, keywords: list, prompt_length=None):
        self.keywords = [kw for kw in (keywords or []) if kw]
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if not self.keywords:
            return False
        if self.prompt_length is None:
            self.prompt_length = input_ids.shape[-1] - 1
        eos_id = getattr(self.tokenizer, 'eos_token_id', None)
        for row in input_ids:
            generated = row[self.prompt_length:]
            if eos_id is not None and bool((generated == eos_id).any()):
                # Row already ended on EOS; anything after it is padding.
                continue
            # Search the whole generated text rather than a short tail, so a
            # row that hit a keyword earlier still counts as finished while
            # the other rows keep generating.
            text = self.tokenizer.decode(generated, skip_special_tokens=True)
            if not any(kw in text for kw in self.keywords):
                return False
        # A single bool ends generation for the whole batch.
        return True
class CodeAI:
    """Answer OpenAI/Copilot-style completion requests with a Hugging Face causal LM.

    Args:
        pretrained: model id or local path passed to ``from_pretrained``.
        max_time: wall-clock limit in seconds passed to ``generate``.
        device: torch device string the model and prompt tensors are moved to.
    """

    def __init__(self, pretrained, max_time=10, device='cuda'):
        self.device = torch.device(device)
        self.max_time = max_time
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
        self.model = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        print("{} model loaded...".format(pretrained))

    @staticmethod
    def _stop_list(data):
        """Return the request's ``stop`` field as a list of non-empty strings.

        The field may be absent, ``None``, a single string, or a list.
        """
        stop = data.get('stop') or []
        if isinstance(stop, str):
            stop = [stop]
        return [s for s in stop if s]

    @staticmethod
    def _num_logprobs(data):
        """Return the request's ``logprobs`` count; missing/None/0 count as 1."""
        return max(int(data.get('logprobs') or 1), 1)

    @staticmethod
    def _log_probs(step_scores):
        """Turn one generation step's score row into finite log-probabilities.

        ``generate(output_scores=True)`` returns processed scores: with plain
        sampling these are warped logits rather than log-probabilities, and
        entries removed by top-p are ``-inf``. ``log_softmax`` normalizes
        them; ``-inf``/NaN are clamped to -1e4 because ``json.dumps`` would
        otherwise emit the non-standard tokens ``-Infinity``/``NaN``.
        """
        log_probs = torch.log_softmax(step_scores.float(), dim=-1)
        return torch.nan_to_num(log_probs, nan=-1e4, neginf=-1e4)

    def trim_with_stopwords(self, output, stopwords):
        """Cut token ids ``output`` just before the earliest stop string.

        Args:
            output: sequence of token ids (list or 1-D tensor).
            stopwords: a stop string or iterable of them; ``None`` and empty
                strings are ignored.

        Returns:
            list[int]: ``output`` unchanged when no stop string occurs in its
            decoded text; otherwise the text before the first stop string,
            re-encoded without special tokens.
        """
        if isinstance(stopwords, str):
            stopwords = [stopwords]
        tokens = [int(t) for t in output]
        text = self.tokenizer.decode(tokens)
        hits = [text.find(s) for s in (stopwords or ()) if s]
        hits = [pos for pos in hits if pos != -1]
        if not hits:
            # Re-encoding untouched text can change the tokenization and
            # misalign tokens with their generation scores, so skip it.
            return tokens
        return self.tokenizer.encode(text[:min(hits)], add_special_tokens=False)

    def generate(self, data):
        """Run ``model.generate`` for one request and return its dict output.

        Reads ``prompt`` (required) and the optional ``top_p`` (default 1.0),
        ``max_tokens`` (default 16), ``n`` (default 1), ``logprobs`` and
        ``stop`` fields of ``data``.
        """
        stops = self._stop_list(data)
        n = max(int(data.get('n') or 1), 1)
        num_logprobs = self._num_logprobs(data)
        # The requested logprobs count doubles as the beam count (original
        # design). Beam search requires num_beams >= num_return_sequences;
        # a single beam means plain sampling.
        num_beams = max(num_logprobs, n) if num_logprobs > 1 else 1
        input_ids = self.tokenizer.encode(data['prompt'], return_tensors='pt').to(self.device)
        stopping_criteria = transformers.StoppingCriteriaList()
        if stops:
            # Only add the keyword criterion when there is something to match.
            stopping_criteria.append(KeywordsStoppingCriteria(self.tokenizer, stops))
        return self.model.generate(
            input_ids=input_ids,
            do_sample=True,
            top_p=data.get('top_p', 1.0),
            max_new_tokens=data.get('max_tokens', 16),
            num_beams=num_beams,
            num_return_sequences=n,
            stopping_criteria=stopping_criteria,
            max_time=self.max_time,
            output_scores=True,
            return_dict_in_generate=True,
        )

    def create_choices(self, data, outputs):
        """Convert ``generate`` output into OpenAI-style ``choices`` entries.

        Each entry holds the completion text (prompt excluded, cut at the
        first EOS token or stop string), ``finish_reason`` ('stop' when EOS
        or a stop string ended it, otherwise 'length') and a ``logprobs``
        object with per-token strings, log-probabilities, the top-``logprobs``
        alternatives per step and character offsets starting at
        ``len(prompt)``.
        """
        stops = self._stop_list(data)
        num_logprobs = self._num_logprobs(data)
        eos = self.model.config.eos_token_id
        # Some model configs list several EOS ids.
        eos_ids = set(eos) if isinstance(eos, (list, tuple)) else {eos}
        eos_ids.discard(None)
        steps = len(outputs.scores)
        # Sequences start with the prompt (decoder-only model); re-encode it
        # the same way generate() did to find where generated tokens begin.
        prompt_len = len(self.tokenizer.encode(data['prompt']))
        # Only beam search reports which score row each step came from; with
        # plain sampling, row ``index`` belongs to sequence ``index``.
        beam_indices = getattr(outputs, 'beam_indices', None)
        choices = []
        for index, sequence in enumerate(outputs.sequences):
            generated = [int(t) for t in sequence[prompt_len:]]
            finish_reason = 'length'
            for pos, tok in enumerate(generated):
                if tok in eos_ids:
                    # Drop EOS and any padding appended after it.
                    generated = generated[:pos]
                    finish_reason = 'stop'
                    break
            # Cap at the number of scored steps so every token has a score row.
            output = self.trim_with_stopwords(generated, stops)[:steps]
            if output != generated:
                finish_reason = 'stop'

            text = self.tokenizer.decode(output, skip_special_tokens=True)
            tokens = [self.tokenizer.decode([tok]) for tok in output]
            text_offset = []
            offset = len(data['prompt'])
            for piece in tokens:
                text_offset.append(offset)
                offset += len(piece)

            token_logprobs, top_logprobs = [], []
            for step, tok in enumerate(output):
                row = index if beam_indices is None else int(beam_indices[index][step])
                step_logprobs = self._log_probs(outputs.scores[step][row])
                token_logprobs.append(float(step_logprobs[tok]))
                top = step_logprobs.topk(k=min(num_logprobs, step_logprobs.shape[-1]))
                top_logprobs.append({
                    self.tokenizer.decode([token_id]): value
                    for token_id, value in zip(top.indices.tolist(), top.values.tolist())
                })

            choices.append({
                'text': text,
                'index': index,
                'finish_reason': finish_reason,
                'logprobs': {
                    'tokens': tokens,
                    'token_logprobs': token_logprobs,
                    'top_logprobs': top_logprobs,
                    'text_offset': text_offset
                }
            })
        return choices

    def __call__(self, data):
        """Yield server-sent-event strings: one ``data:`` event per choice, then ``data: [DONE]``."""
        outputs = self.generate(data)
        choices = self.create_choices(data, outputs)
        # One id and timestamp per request so every chunk of this response
        # shares them, as in OpenAI's streaming format.
        completion_id = 'cmpl-' + ''.join(
            random.choice(string.ascii_letters + string.digits) for _ in range(29))
        created = int(time.time())
        for choice in choices:
            completion = json.dumps({
                'id': completion_id,
                'model': 'codeai',
                'created': created,
                'choices': [choice]
            })
            yield 'data: {}\n\n'.format(completion)
        yield 'data: [DONE]\n\n'
# Load the model once at import time. Fall back to CPU when no CUDA device is
# available instead of failing when the model is moved to 'cuda'.
# A larger model such as "EleutherAI/gpt-j-6B" can be used instead if the
# hardware allows.
codeai = CodeAI(pretrained="lvwerra/codeparrot-small",
                device='cuda' if torch.cuda.is_available() else 'cpu')
app = Flask(__name__)
@app.route("/v1/engines/codeai/completions", methods=["POST"])
def completions():
    """Stream completions for an OpenAI/Copilot-style JSON request as server-sent events.

    The body is parsed as JSON regardless of Content-Type. A body that is not
    a JSON object gets a 400 response instead of an unhandled 500.
    """
    # The route only accepts POST, so no method check is needed here.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return app.response_class(
            json.dumps({'error': 'request body must be a JSON object'}),
            status=400,
            mimetype='application/json',
        )
    print(data)
    return app.response_class(codeai(data), mimetype='text/event-stream')
if __name__ == '__main__':
    import os

    # Flask's development server binds 127.0.0.1:5000 by default. Set HOST
    # (e.g. 0.0.0.0) and/or PORT to serve at the address configured as
    # Copilot's proxy URL; unset variables keep Flask's defaults.
    port = os.environ.get('PORT')
    app.run(host=os.environ.get('HOST'), port=int(port) if port else None)
@TechnologyClassroom
Copy link

TechnologyClassroom commented Aug 3, 2022

This gist should be a full repo. This is incredible stuff!

I see you have no LICENSE header for this project. Without one, the default is exclusive copyright (all rights reserved).

I would suggest releasing the code under the GPL-3.0-or-later or AGPL-3.0-or-later license so that others are encouraged to contribute changes back to your project.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment