Skip to content

Instantly share code, notes, and snippets.

@izzymiller
Created April 10, 2023 23:59
Show Gist options
  • Select an option

  • Save izzymiller/2ea987b90e6c96a005cb9026b9243c8e to your computer and use it in GitHub Desktop.

Select an option

Save izzymiller/2ea987b90e6c96a005cb9026b9243c8e to your computer and use it in GitHub Desktop.
import modal
import os
#create a shared volume to store weights
# Persisted under the name "robo-boys-vol"; mounted at /models by the
# functions below so model checkpoints survive across container runs.
volume = modal.SharedVolume().persist("robo-boys-vol")
#create a modal "stub" to handle config for functions
# The container image bundles the inference dependencies; transformers is
# installed from GitHub main rather than PyPI — presumably for LLaMA support
# not yet in a released version (NOTE(review): confirm, then pin a version).
stub = modal.Stub(
    "robo-boys-predict",
    image=modal.Image.debian_slim().pip_install("numpy",
    "rouge-score",
    "fire",
    "torch",
    "sentencepiece",
    "firebase-admin",
    "tokenizers").apt_install('git').run_commands('pip install git+https://github.com/huggingface/transformers')
)
#This is a one time function to download my weights. I'd probably use the Modal CLI for this next time.
@stub.function(shared_volumes={"/models": volume},secrets=[modal.Secret.from_name("robo-boys-secrets")])
def download_model():
    """One-off job: copy the model checkpoint from S3 into the shared
    /models volume so inference containers can load it locally.

    Relies on AWS credentials being available via the "robo-boys-secrets"
    Modal secret. Returns nothing; progress is only visible in the logs.
    """
    print('downloading model from aws')
    # Debug listings so the logs show the volume/workdir state before copying.
    # (Fixed: these were f-strings with no placeholders.)
    os.system("ls /models")
    os.system("ls")
    os.system('aws configure list')
    # NOTE: the S3 URI below is a placeholder — point it at your checkpoint.
    os.system("aws s3 cp --recursive s3://path/to/your/checkpoint /models/model")
    print('downloaded model from aws')
class MessagePrediction:
    """Modal container class that generates group-chat conversations with a
    fine-tuned LLaMA model and mirrors each generated message into Firestore.

    Lifecycle: Modal invokes ``__enter__`` once per warm container, which
    authenticates to Firebase and loads the model from the shared /models
    volume onto the GPU. ``create_conversation`` is the main entry point.
    """
    def __enter__(self):
        # Container-startup hook: set up the Firestore client and load the
        # model + tokenizer. Imports are local so they only resolve inside
        # the Modal image, not on the machine that deploys this file.
        import transformers
        import firebase_admin
        from firebase_admin import credentials
        from firebase_admin import firestore
        import json
        # Service-account JSON is injected as an env var by the Modal secret.
        service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
        cred = credentials.Certificate(service_account_info)
        app = firebase_admin.initialize_app(cred)
        # Create a Firestore client
        self.db = firestore.client()
        m_inter = transformers.LlamaForCausalLM.from_pretrained("/models/model")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("/models/model")
        # Half precision — presumably to fit the weights on the A10G; confirm.
        m_inter = m_inter.half()
        self.model = m_inter.to("cuda")
    @stub.function(gpu=modal.gpu.A10G(count=1), shared_volumes={"/models": volume},secrets=[modal.Secret.from_name("firebase-svc")],container_idle_timeout=1200,timeout=500,concurrency_limit=1)
    def create_conversation(self,init_context: str,wake: bool):
        """Generate a multi-message conversation seeded by ``init_context``.

        init_context: a message of the form "Sender: text" — it is split on
        the first ':' to derive the most recent sender and message.
        wake: when True, returns immediately (used only to warm the container).
        Returns the accumulated conversation text (or None on the wake path).
        Each accepted message is also written to Firestore as a side effect.
        """
        import random
        import traceback
        if wake: # just a way to wake up this function!
            return
        ctx = ''
        # conditionally get 'background' context on chat if desired, helpful to keep conversations going across multiple prompts.
        background = self.get_firestore_context()
        if len(background) > 0:
            ctx = background + '\n' + init_context
        else:
            ctx = init_context
        print(ctx)
        counter = 0            # messages actually accepted into the conversation
        backup_counter = 0     # total generation attempts, including rejects
        most_recent_sender = init_context.split(":")[0]
        most_recent_message = init_context.split(":")[1]
        # quick and dirty loop to generate an entire conversation. These probabilities are based off the actual distribution of messages in the chat archive.
        while counter <= 12 and backup_counter <= 40:
            try:
                backup_counter += 1 #prevent infinite loops due to reaction chains
                characters = ['Wyatt','Kiebs','Izzy','Luke','Harvey','Henry']
                character_probabilities = [0.15,0.4,0.6,0.1,0.3,0.6]
                most_recent_index = characters.index(most_recent_sender)
                if counter == 0:
                    # Never let the seed sender reply to themselves first.
                    character_probabilities[most_recent_index] = 0
                else:
                    # Otherwise the last speaker is a bit more likely to continue.
                    character_probabilities[most_recent_index] += .2
                # Boost whoever was mentioned by name in the last message.
                most_recent_referenced = ''
                if 'adam' in most_recent_message or 'Adam' in most_recent_message or 'kiebs' in most_recent_message or 'Kiebs' in most_recent_message:
                    most_recent_referenced = 'Kiebs'
                elif 'wyatt' in most_recent_message or 'Wyatt' in most_recent_message:
                    most_recent_referenced = 'Wyatt'
                elif 'izzy' in most_recent_message or 'Izzy' in most_recent_message or 'iz' in most_recent_message:
                    # NOTE(review): the bare 'iz' substring also matches words
                    # like "size"/"realize" — confirm this over-trigger is OK.
                    most_recent_referenced = 'Izzy'
                elif 'luke' in most_recent_message or 'Luke' in most_recent_message:
                    most_recent_referenced = 'Luke'
                elif 'harv' in most_recent_message or 'Harv' in most_recent_message:
                    most_recent_referenced = 'Harvey'
                elif 'hen' in most_recent_message or 'Hen' in most_recent_message or 'Hank' in most_recent_message:
                    most_recent_referenced = 'Henry'
                if len(most_recent_referenced) > 0:
                    referenced_index = characters.index(most_recent_referenced)
                    character_probabilities[referenced_index] += .7
                # random.choices accepts relative (non-normalized) weights.
                character = random.choices(characters,character_probabilities)[0]
                res = self.predict(context=ctx,character=character)
                # Keep only the last instruction/response pair of the prompt echo.
                temp = ''
                for i in res.split("###")[-2:]:
                    temp += i
                if len(temp.split("Response:")) < 2:
                    print(temp)
                    print('split: ',temp.split("Response:"))
                    print('no completion generated, skipping')
                    continue
                temp = temp.split("Response:")[1]
                temp = temp.replace("</s>","")
                if u'\uFFFC' in temp: #this is the character used to represent images in the model, unnecessary if you cleaned them out prior.
                    continue
                if 'https://' in temp:
                    print('just link, not incrementing counter')
                    continue
                if 'Loved' in temp or 'Laughed' in temp or 'Disliked' in temp or 'Emphasized' in temp or 'Liked' in temp:
                    # iMessage "tapback" reactions — skip, they aren't real messages.
                    print('suppressing reaction')
                    continue
                # Accepted: publish to Firestore and fold into the rolling context.
                m = self.dispatch_msg_to_firestore(temp,character)
                text = f"{character}: {temp}"
                ctx += '\n'
                ctx += text
                # Strip a prompt-template fragment the model sometimes echoes.
                ctx = ctx.replace(' "Input:\n','')
                counter+= 1
            except Exception:
                # Best-effort: log the traceback and stop generating.
                print('failed message',traceback.format_exc())
                break
        return ctx
    #exists only to wake the container
    @stub.function(gpu=modal.gpu.A10G(count=1), shared_volumes={"/models": volume},secrets=[modal.Secret.from_name("firebase-svc")],container_idle_timeout=1200,timeout=500,concurrency_limit=1)
    def wake():
        # NOTE(review): defined inside the class without `self` — presumably
        # only ever invoked through Modal's function wrapper; confirm it is
        # never called as a normal instance method.
        print('waking up')
    def predict(self, context: str,character: str,temp=.8):
        """Build the character prompt around ``context`` and run one
        generation pass; returns the full decoded text (prompt echo included —
        the caller extracts the response segment).

        NOTE(review): the ``temp`` parameter is never used; ``generate`` is
        called with a hard-coded temperature of .75 below.
        """
        #unnecessary old code from alpaca, just haven't cleaned it yet.
        PROMPT_DICT = {
            "prompt_input": (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
            ),
            "prompt_no_input": (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Response:"
            ),
        }
        PROMPT = PROMPT_DICT['prompt_input']
        characters = ['Wyatt','Kiebs','Izzy','Luke','Harvey','Henry']
        # Everyone except the speaking character, listed in the system prompt.
        others = list(filter(lambda x: x!=character,characters))
        prompt = f"Your name is {character}. You are in a group chat with 5 of your best friends: {', '.join(others)}. You all went to college together, and talk to each other with no filter. You are encouraged to curse, say generally inappropriate and hilarious things, or be very rude. Everything is in good fun, so remember to joke and laugh, and be funny, but also answer all questions clearly. You will be presented with the most recent messages in the group chat. Write a response to the conversation as {character}."
        format_prompt = PROMPT.format_map({'instruction': prompt,'input':context})
        tokenized_text = self.tokenizer(format_prompt, return_tensors="pt", padding="longest", max_length=self.tokenizer.model_max_length, truncation=True)
        full_completion = self.model.generate(inputs=tokenized_text["input_ids"].to("cuda"),
            attention_mask=tokenized_text["attention_mask"].to("cuda"),
            temperature=.75,
            top_p=0.85,
            top_k=80,
            do_sample=True,
            num_beams=3,
            max_new_tokens=600,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            repetition_penalty=1)
        decoded_text = self.tokenizer.decode(full_completion[0])
        return decoded_text
    def dispatch_msg_to_firestore(self,message,sender):
        """Write one generated message to the chat's Firestore collection.

        sender: one of the character display names keyed below.
        Returns the created document's create_time.
        Raises KeyError if ``sender`` is not a known character.
        """
        from datetime import datetime,timezone
        import time
        # I delay to make the conversation more realistic on the front-end. Could save a ton of money probably by doing this delay on the frontend instead!
        time.sleep(0.25)
        # Static per-character identities used by the chat front-end.
        senders = {
            'Henry': {
                'uid': 'fake-henry',
                'photo': 'https://i.imgur.com/wdXWHz2.jpg',
                'email': 'fake@email.com',
                'displayName': 'Henry'
            },
            'Harvey': {
                'uid': 'fake-harvey',
                'photo': 'https://i.imgur.com/sU8Codw.jpg',
                'email': 'fake@email.com',
                'displayName': 'Harvey'
            },
            'Luke': {
                'uid': 'fake-luke',
                'photo': 'https://i.imgur.com/U645ciG.jpg',
                'email': 'fake@email.com',
                'displayName': 'Luke'
            },
            'Izzy': {
                'uid': 'fake-izzy',
                'photo': 'https://i.imgur.com/wUGEnVb.jpg',
                'email': 'fake@email.com',
                'displayName': 'Izzy'
            },
            'Kiebs': {
                'uid': 'fake-kiebs',
                'photo': 'https://i.imgur.com/ESoUipA.png',
                'email': 'fake@email.com',
                'displayName': 'Kiebs'
            },
            'Wyatt': {
                'uid': 'fake-wyatt',
                'photo': 'https://i.imgur.com/9yPKaac.jpg',
                'email': 'fake@email.com',
                'displayName': 'Wyatt'
            }
        }
        sender = senders[sender]
        # '<chatdb>' is a placeholder document id — replace with the real chat id.
        chat_doc_ref = self.db.collection('chats').document('<chatdb>')
        chat_messages_ref = chat_doc_ref.collection('messages')
        create_time, doc_ref = chat_messages_ref.add({
            'timestamp': datetime.now(timezone.utc),
            'message': message,
            'uid': sender['uid'],
            'photo': sender['photo'],
            'email': sender['email'],
            'displayName': sender['displayName'],
        })
        return create_time
    def get_firestore_context(self):
        """Fetch recent chat history to seed the model's context.

        Returns up to the last 10 messages formatted as "Name : text" lines,
        or '' when the chat has been idle for more than 4 minutes or when a
        consecutive duplicate message is found (early abort).

        NOTE(review): messages are appended in DESCENDING timestamp order, so
        the returned context reads newest-first — confirm the model was
        trained on that ordering.
        """
        from firebase_admin import firestore
        from datetime import datetime, timedelta,timezone
        chat_doc_ref = self.db.collection('chats').document('<chatdb>')
        chat_messages_ref = chat_doc_ref.collection('messages')
        most_recent_message = chat_messages_ref.order_by('timestamp', direction=firestore.Query.DESCENDING).limit(1).get()[0]
        message_timestamp = most_recent_message.get('timestamp')
        current_time = datetime.now(timezone.utc)
        time_diff = current_time - message_timestamp
        if time_diff <= timedelta(minutes=4):
            messages = chat_messages_ref.order_by('timestamp', direction=firestore.Query.DESCENDING).limit(10).get()
            ctx = ''
            prev = ''
            for i in messages:
                raw = i.to_dict()
                if prev == raw['message']:
                    # Consecutive duplicate — bail out with no context at all.
                    return ''
                msg = f"{raw['displayName']} : {raw['message']}"
                ctx += msg
                ctx += '\n'
                prev = raw['message']
            return ctx
        else:
            return ''
# just for testing
@stub.webhook
def get_completion(context: str):
    """Test webhook: run a full conversation from ``context`` and return it
    as an HTML page with line breaks rendered as <br /> tags."""
    from fastapi.responses import HTMLResponse
    transcript = MessagePrediction().create_conversation.call(init_context=context, wake=False)
    return HTMLResponse(transcript.replace("\n", "<br />"))
@stub.webhook(label="alive", image=modal.Image.debian_slim())
def check_alive():
    """Liveness probe: report the GPU conversation function's current stats."""
    print('Checking status of GPU container')
    predictor = MessagePrediction()
    return predictor.create_conversation.get_current_stats()
@stub.webhook(label="wake")
def wake():
    """Webhook that warms the GPU container by spawning a no-op conversation
    call (wake=True makes it return immediately)."""
    predictor = MessagePrediction()
    predictor.create_conversation.spawn(init_context='wake', wake=True)
    print('waking up container')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment