-
-
Save izzymiller/2ea987b90e6c96a005cb9026b9243c8e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import modal
import os

# Shared volume: persists the downloaded model weights across container runs
# so they only need to be fetched from S3 once.
volume = modal.SharedVolume().persist("robo-boys-vol")

# Modal "stub": holds the app name and the container image config shared by
# all functions below. transformers is installed from git rather than PyPI
# (presumably for LLaMA support newer than the released package -- TODO confirm).
stub = modal.Stub(
    "robo-boys-predict",
    image=modal.Image.debian_slim().pip_install("numpy",
        "rouge-score",
        "fire",
        "torch",
        "sentencepiece",
        "firebase-admin",
        "tokenizers").apt_install('git').run_commands('pip install git+https://github.com/huggingface/transformers')
)
# This is a one time function to download my weights. I'd probably use the
# Modal CLI for this next time.
@stub.function(shared_volumes={"/models": volume}, secrets=[modal.Secret.from_name("robo-boys-secrets")])
def download_model():
    """Download the model checkpoint from S3 into /models/model on the shared volume.

    AWS credentials are expected to come from the 'robo-boys-secrets' Modal
    secret. Raises RuntimeError if the S3 copy exits non-zero, so a failed
    download is visible instead of silently leaving an empty volume.
    """
    print('downloading model from aws')
    # Debug output: show current volume contents and confirm AWS CLI config.
    # (These listings are best-effort; their exit status is intentionally ignored.)
    os.system("ls /models")
    os.system("ls")
    os.system('aws configure list')
    # NOTE: the bucket path is a placeholder -- point it at your own checkpoint.
    status = os.system("aws s3 cp --recursive s3://path/to/your/checkpoint /models/model")
    if status != 0:
        raise RuntimeError(f"aws s3 cp failed with exit status {status}")
    print('downloaded model from aws')
class MessagePrediction:
    """Generates simulated group-chat conversations with a fine-tuned LLaMA model.

    The model and tokenizer are loaded once per Modal container start (in
    __enter__), and each generated message is written to Firestore as it is
    produced so the front-end can render the conversation live.
    """

    def __enter__(self):
        # Container-startup hook: import heavyweight deps, init Firebase, and
        # load the model onto the GPU, so cold-start cost is paid only once.
        import transformers
        import firebase_admin
        from firebase_admin import credentials
        from firebase_admin import firestore
        import json
        # Service-account JSON is injected via the Modal secret as an env var.
        service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
        cred = credentials.Certificate(service_account_info)
        app = firebase_admin.initialize_app(cred)
        # Create a Firestore client
        self.db = firestore.client()
        # Weights live on the shared volume mounted at /models (see download_model).
        m_inter = transformers.LlamaForCausalLM.from_pretrained("/models/model")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("/models/model")
        # fp16 on GPU for inference.
        m_inter = m_inter.half()
        self.model = m_inter.to("cuda")

    @stub.function(gpu=modal.gpu.A10G(count=1), shared_volumes={"/models": volume}, secrets=[modal.Secret.from_name("firebase-svc")], container_idle_timeout=1200, timeout=500, concurrency_limit=1)
    def create_conversation(self, init_context: str, wake: bool):
        """Generate an entire conversation seeded by init_context.

        init_context is expected to look like "Sender: message" (split on the
        first ':' below). Each accepted message is dispatched to Firestore and
        appended to the running context. If wake is True, return immediately
        (used only to warm this container). Returns the final context string.
        """
        import random
        import traceback
        if wake:  # just a way to wake up this function!
            return
        ctx = ''
        # conditionally get 'background' context on chat if desired, helpful to
        # keep conversations going across multiple prompts.
        background = self.get_firestore_context()
        if len(background) > 0:
            ctx = background + '\n' + init_context
        else:
            ctx = init_context
        print(ctx)
        counter = 0
        backup_counter = 0
        # NOTE(review): assumes init_context contains a ':' and that the text
        # before it is one of the known character names -- otherwise the
        # .index() call below raises (caught by the try/except, ending the loop).
        most_recent_sender = init_context.split(":")[0]
        most_recent_message = init_context.split(":")[1]
        # quick and dirty loop to generate an entire conversation. These
        # probabilities are based off the actual distribution of messages in
        # the chat archive.
        while counter <= 12 and backup_counter <= 40:
            try:
                backup_counter += 1  # prevent infinite loops due to reaction chains
                characters = ['Wyatt','Kiebs','Izzy','Luke','Harvey','Henry']
                character_probabilities = [0.15,0.4,0.6,0.1,0.3,0.6]
                most_recent_index = characters.index(most_recent_sender)
                # The opener never replies to themselves first; after that, the
                # most recent speaker gets a small boost to keep exchanges going.
                if counter == 0:
                    character_probabilities[most_recent_index] = 0
                else:
                    character_probabilities[most_recent_index] += .2
                # If someone was mentioned by name/nickname, make them much more
                # likely to respond next.
                # NOTE(review): most_recent_sender/most_recent_message are never
                # updated inside the loop, so these boosts always key off the
                # ORIGINAL seed message -- confirm whether that is intentional.
                most_recent_referenced = ''
                if 'adam' in most_recent_message or 'Adam' in most_recent_message or 'kiebs' in most_recent_message or 'Kiebs' in most_recent_message:
                    most_recent_referenced = 'Kiebs'
                elif 'wyatt' in most_recent_message or 'Wyatt' in most_recent_message:
                    most_recent_referenced = 'Wyatt'
                elif 'izzy' in most_recent_message or 'Izzy' in most_recent_message or 'iz' in most_recent_message:
                    most_recent_referenced = 'Izzy'
                elif 'luke' in most_recent_message or 'Luke' in most_recent_message:
                    most_recent_referenced = 'Luke'
                elif 'harv' in most_recent_message or 'Harv' in most_recent_message:
                    most_recent_referenced = 'Harvey'
                elif 'hen' in most_recent_message or 'Hen' in most_recent_message or 'Hank' in most_recent_message:
                    most_recent_referenced = 'Henry'
                if len(most_recent_referenced) > 0:
                    referenced_index = characters.index(most_recent_referenced)
                    character_probabilities[referenced_index] += .7
                # Weighted draw of who speaks next.
                character = random.choices(characters,character_probabilities)[0]
                res = self.predict(context=ctx,character=character)
                # Keep only the last two '###'-delimited sections of the decoded
                # output, i.e. the Input/Response tail of the alpaca-style prompt.
                temp = ''
                for i in res.split("###")[-2:]:
                    temp += i
                if len(temp.split("Response:")) < 2:
                    print(temp)
                    print('split: ',temp.split("Response:"))
                    print('no completion generated, skipping')
                    continue
                temp = temp.split("Response:")[1]
                temp = temp.replace("</s>","")
                if u'\uFFFC' in temp:  # this is the character used to represent images in the model, unnecessary if you cleaned them out prior.
                    continue
                if 'https://' in temp:
                    print('just link, not incrementing counter')
                    continue
                # Drop iMessage-style reaction messages ("Loved ...", etc.).
                if 'Loved' in temp or 'Laughed' in temp or 'Disliked' in temp or 'Emphasized' in temp or 'Liked' in temp:
                    print('suppressing reaction')
                    continue
                # Push the accepted message to Firestore, then fold it into ctx.
                m = self.dispatch_msg_to_firestore(temp,character)
                text = f"{character}: {temp}"
                ctx += '\n'
                ctx += text
                # Strip stray prompt-scaffolding fragments the model may echo.
                ctx = ctx.replace(' "Input:\n','')
                counter+= 1
            except Exception:
                # Best-effort loop: log the traceback and stop generating.
                print('failed message',traceback.format_exc())
                break
        return ctx

    # exists only to wake the container
    # NOTE(review): defined without `self` although it sits inside the class --
    # confirm Modal invokes it in a way that tolerates this.
    @stub.function(gpu=modal.gpu.A10G(count=1), shared_volumes={"/models": volume}, secrets=[modal.Secret.from_name("firebase-svc")], container_idle_timeout=1200, timeout=500, concurrency_limit=1)
    def wake():
        print('waking up')

    def predict(self, context: str, character: str, temp=.8):
        """Build the character prompt for `context` and generate one completion.

        Returns the full decoded text (prompt + completion); the caller parses
        out the response. NOTE(review): the `temp` parameter is currently
        unused -- generation always runs with temperature=.75 below.
        """
        # unnecessary old code from alpaca, just haven't cleaned it yet.
        PROMPT_DICT = {
            "prompt_input": (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
            ),
            "prompt_no_input": (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Response:"
            ),
        }
        PROMPT = PROMPT_DICT['prompt_input']
        characters = ['Wyatt','Kiebs','Izzy','Luke','Harvey','Henry']
        # Everyone except the speaking character, for the persona prompt.
        others = list(filter(lambda x: x!=character,characters))
        prompt = f"Your name is {character}. You are in a group chat with 5 of your best friends: {', '.join(others)}. You all went to college together, and talk to each other with no filter. You are encouraged to curse, say generally inappropriate and hilarious things, or be very rude. Everything is in good fun, so remember to joke and laugh, and be funny, but also answer all questions clearly. You will be presented with the most recent messages in the group chat. Write a response to the conversation as {character}."
        format_prompt = PROMPT.format_map({'instruction': prompt,'input':context})
        tokenized_text = self.tokenizer(format_prompt, return_tensors="pt", padding="longest", max_length=self.tokenizer.model_max_length, truncation=True)
        full_completion = self.model.generate(inputs=tokenized_text["input_ids"].to("cuda"),
            attention_mask=tokenized_text["attention_mask"].to("cuda"),
            temperature=.75,
            top_p=0.85,
            top_k=80,
            do_sample=True,
            num_beams=3,
            max_new_tokens=600,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            repetition_penalty=1)
        decoded_text = self.tokenizer.decode(full_completion[0])
        return decoded_text

    def dispatch_msg_to_firestore(self,message,sender):
        """Write one generated message to the chat's Firestore collection.

        `sender` is a character name used to look up the fake user profile.
        Returns the Firestore create_time of the new document.
        """
        from datetime import datetime,timezone
        import time
        # I delay to make the conversation more realistic on the front-end.
        # Could save a ton of money probably by doing this delay on the
        # frontend instead!
        time.sleep(0.25)
        # Static fake user profiles so messages render with avatars/names.
        senders = {
            'Henry': {
                'uid': 'fake-henry',
                'photo': 'https://i.imgur.com/wdXWHz2.jpg',
                'email': 'fake@email.com',
                'displayName': 'Henry'
            },
            'Harvey': {
                'uid': 'fake-harvey',
                'photo': 'https://i.imgur.com/sU8Codw.jpg',
                'email': 'fake@email.com',
                'displayName': 'Harvey'
            },
            'Luke': {
                'uid': 'fake-luke',
                'photo': 'https://i.imgur.com/U645ciG.jpg',
                'email': 'fake@email.com',
                'displayName': 'Luke'
            },
            'Izzy': {
                'uid': 'fake-izzy',
                'photo': 'https://i.imgur.com/wUGEnVb.jpg',
                'email': 'fake@email.com',
                'displayName': 'Izzy'
            },
            'Kiebs': {
                'uid': 'fake-kiebs',
                'photo': 'https://i.imgur.com/ESoUipA.png',
                'email': 'fake@email.com',
                'displayName': 'Kiebs'
            },
            'Wyatt': {
                'uid': 'fake-wyatt',
                'photo': 'https://i.imgur.com/9yPKaac.jpg',
                'email': 'fake@email.com',
                'displayName': 'Wyatt'
            }
        }
        sender = senders[sender]
        # '<chatdb>' is a placeholder document id -- substitute your own chat doc.
        chat_doc_ref = self.db.collection('chats').document('<chatdb>')
        chat_messages_ref = chat_doc_ref.collection('messages')
        create_time, doc_ref = chat_messages_ref.add({
            'timestamp': datetime.now(timezone.utc),
            'message': message,
            'uid': sender['uid'],
            'photo': sender['photo'],
            'email': sender['email'],
            'displayName': sender['displayName'],
        })
        return create_time

    def get_firestore_context(self):
        """Fetch recent chat history to use as background context.

        Returns up to the last 10 messages (newest first, as queried) formatted
        as "Name : message" lines, but only when the most recent message is
        under 4 minutes old; otherwise returns ''. Also returns '' if two
        consecutive fetched messages are identical (loop-detection guard).
        """
        from firebase_admin import firestore
        from datetime import datetime, timedelta,timezone
        chat_doc_ref = self.db.collection('chats').document('<chatdb>')
        chat_messages_ref = chat_doc_ref.collection('messages')
        most_recent_message = chat_messages_ref.order_by('timestamp', direction=firestore.Query.DESCENDING).limit(1).get()[0]
        message_timestamp = most_recent_message.get('timestamp')
        current_time = datetime.now(timezone.utc)
        time_diff = current_time - message_timestamp
        # Only treat history as relevant if the chat is still "live".
        if time_diff <= timedelta(minutes=4):
            messages = chat_messages_ref.order_by('timestamp', direction=firestore.Query.DESCENDING).limit(10).get()
            ctx = ''
            prev = ''
            for i in messages:
                raw = i.to_dict()
                # Bail out entirely on an immediate duplicate message.
                if prev == raw['message']:
                    return ''
                msg = f"{raw['displayName']} : {raw['message']}"
                ctx += msg
                ctx += '\n'
                prev = raw['message']
            return ctx
        else:
            return ''
# just for testing
@stub.webhook
def get_completion(context: str):
    """Debug endpoint: generate a full conversation from `context` and render it as HTML."""
    from fastapi.responses import HTMLResponse
    conversation_text = MessagePrediction().create_conversation.call(
        init_context=context, wake=False
    )
    # Newlines -> <br /> so the transcript displays line-by-line in a browser.
    return HTMLResponse(conversation_text.replace("\n", "<br />"))
@stub.webhook(label="alive", image=modal.Image.debian_slim())
def check_alive():
    """Liveness probe: return current runtime stats for the GPU container."""
    print('Checking status of GPU container')
    container_stats = MessagePrediction().create_conversation.get_current_stats()
    return container_stats
@stub.webhook(label="wake")
def wake():
    """Fire-and-forget warm-up of the GPU container ahead of real traffic."""
    print('waking up container')
    # spawn (not call): don't block the webhook; wake=True makes the
    # spawned function return immediately after the container is up.
    MessagePrediction().create_conversation.spawn(init_context='wake', wake=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment