Last active
June 27, 2024 12:16
-
-
Save JGalego/73d377b86e363586f95a361fc1085cde to your computer and use it in GitHub Desktop.
Chat with an SLM running on AWS Lambda 🤖
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
r""" | |
_____ _ __ __ _ _ _____ _ _ | |
/ ____| | | \/ | | | | / ____| | | | | |
| (___ | | __ _| \ / | |__ __| | __ _ | | | |__ __ _| |_ | |
\___ \| | / _` | |\/| | '_ \ / _` |/ _` | | | | '_ \ / _` | __| | |
____) | |___| (_| | | | | |_) | (_| | (_| | | |____| | | | (_| | |_ | |
|_____/|______\__,_|_| |_|_.__/ \__,_|\__,_| \_____|_| |_|\__,_|\__| | |
""" | |
import codecs
import json
import os

import boto3
import streamlit as st
# Page header. The robot emoji is written as a real Unicode character; the
# previous text was a mis-decoded (mojibake) rendering of its UTF-8 bytes.
st.title("SLaMbda Chat 🤖")
st.markdown("Learn how to run small language models (SLMs) at scale on [AWS Lambda](https://aws.amazon.com/lambda/).")

# The name of the Lambda function to invoke is read from the FUNCTION_NAME
# environment variable; the app cannot work without it, so fail loudly.
try:
    slambda_function_name = os.environ['FUNCTION_NAME']
except KeyError as exc:
    raise ValueError("SLaMbda function name is not defined!") from exc


@st.cache_resource
def _get_lambda_client():
    """Return a Lambda client that is created once and reused across reruns.

    Streamlit re-executes this script from top to bottom on every user
    interaction, so an uncached ``boto3.client`` call would build a new
    client for each message.
    """
    return boto3.client('lambda')


# Initialize Lambda Client
client = _get_lambda_client()

# Initialize message history for display purposes
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display all messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# Get user input
if prompt := st.chat_input("What's up, doc? 🐰🥕"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Process model response
    with st.chat_message("assistant"):
        # The prompt is JSON-encoded twice: once as the string stored under
        # "body", and once more for the invocation payload itself.
        payload = {"body": json.dumps({"message": prompt})}
        response_stream = client.invoke_with_response_stream(
            FunctionName=slambda_function_name,
            Payload=json.dumps(payload).encode("utf-8"),
        )
        event_stream = response_stream.get('EventStream', {})

        # Filled in by stream_data() when the stream ends with an error, so
        # the failure can be reported once streaming has finished.
        invoke_error = {}

        def stream_data():
            """Yield the SLaMbda output as text pieces for ``st.write_stream``.

            Payload bytes are fed through an incremental UTF-8 decoder so a
            multi-byte character split across two chunks is reassembled
            instead of raising ``UnicodeDecodeError``. Invalid bytes are
            rendered as U+FFFD rather than aborting the whole response.
            Events without a ``PayloadChunk`` yield no text; if such an event
            is an ``InvokeComplete`` carrying an ``ErrorCode``, its details
            are recorded in ``invoke_error``.
            """
            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
            for event in event_stream:
                chunk = event.get("PayloadChunk")
                if chunk is not None:
                    text = decoder.decode(chunk.get("Payload") or b"")
                    if text:
                        yield text
                    continue
                complete = event.get("InvokeComplete") or {}
                if complete.get("ErrorCode"):
                    invoke_error.update(complete)
            # Flush any bytes still buffered when the stream ends.
            tail = decoder.decode(b"", final=True)
            if tail:
                yield tail

        response = st.write_stream(stream_data)

        # Surface a failed invocation instead of silently showing an empty
        # or truncated answer.
        if invoke_error:
            st.error(
                "SLaMbda invocation failed: "
                f"{invoke_error.get('ErrorCode')}: "
                f"{invoke_error.get('ErrorDetails', '')}"
            )
    st.session_state.messages.append({"role": "assistant", "content": response})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment