-
-
Save sqlinsights/c9918ce4b62c8d8b97e216534a9b0768 to your computer and use it in GitHub Desktop.
A simple Cortex-driven toolbox built with Streamlit in Snowflake.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
from snowflake.snowpark.context import get_active_session | |
from snowflake.cortex import Complete | |
from enum import Enum | |
# Per-session defaults, applied only on the first run of a browser session;
# later reruns leave any existing values (e.g. a growing chat history) alone.
_SESSION_DEFAULTS = {
    'prompt_history': [{"response" : "**Welcome!** \n\n Here you can ask anything you need."}],
    'workflow': None,
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value

# Snowpark session from get_active_session(); not referenced elsewhere in this
# file (NOTE(review): confirm it is still needed).
session = get_active_session()
# Cortex model name passed to every Complete() call in this app.
MODEL = 'reka-flash'
class SystemMessages(Enum):
    """System instructions for the code tools; each value is sent as the
    first ``system`` message of a prompt."""

    # Instruction for the "Explain" tool.
    EXPLAIN = (
        'Provide a detailed explanation including Language written-in. '
        'Must have a General Explanation of the code, breakdown of steps and an output section. '
        'Use Markdown as the output with header and subheaders for each section'
    )
    # Instruction for the "Format" tool.
    FORMAT = (
        'Provide a formatted version of the code. '
        'Return in a markdown style with a codeblock. '
        'Do not truncate output'
    )
def submit_prompt(prompt: str) -> None:
    """Send the chat text stored in session state to Cortex and record the exchange.

    Wired as the ``on_submit`` callback of ``st.chat_input``. ``prompt`` is the
    session-state KEY of the input widget, not the text itself.
    The ``{'prompt': ..., 'response': ...}`` pair is appended to
    ``st.session_state['prompt_history']``. Any error is rendered with
    ``st.exception`` instead of being raised.
    """
    try:
        with st.spinner("Awaiting response"):
            question = st.session_state[prompt]
            answer = Complete(MODEL, question).strip()
        st.session_state['prompt_history'].append({'prompt': question, 'response': answer})
    except Exception as e:
        st.exception(e)
def compose_prompt(prompt: str, workflow: 'SystemMessages', **kwargs) -> str:
    """Build the prompt text for a code-tool workflow.

    Args:
        prompt: The user's code or request, sent as the ``user`` message.
        workflow: Enum member whose ``.value`` is the system instruction.
        **kwargs: Optional ``language``. When it is a non-empty value, an extra
            system message asks the model to use that language.

    Returns:
        ``str()`` of the message list ``[system, (language system,) user]``.
    """
    messages = [{'role': 'system', 'content': workflow.value}]
    language = kwargs.get('language')
    # Skip the hint for a missing/blank language (e.g. "Other" left empty) instead
    # of sending "Format using  as the language".
    if language:
        messages.append({'role': 'system', 'content': f'Format using {language} as the language'})
    messages.append({'role': 'user', 'content': prompt})
    # NOTE(review): str() flattens the roles into plain text that only looks like a
    # message list. snowflake.cortex.Complete is documented to also accept a list of
    # role/content dicts; passing `messages` directly would preserve the roles, but it
    # changes this function's return type. Verify before switching.
    return str(messages)
def utility_prompt(source: str, workflow: SystemMessages, **kwargs) -> str:
    """Run a code-tool workflow over ``source`` with the Cortex model.

    ``kwargs`` are forwarded to ``compose_prompt`` (e.g. ``language``).

    Returns:
        The stripped completion text, or ``''`` if building the prompt or
        calling the model fails; in that case the error is shown as toasts.
    """
    try:
        prompt = compose_prompt(source, workflow, **kwargs)
        return Complete(model=MODEL, prompt=prompt).strip()
    except Exception as e:
        # UI boundary: report the failure instead of crashing the page.
        st.toast("Error generating response")
        # st.toast refuses a blank body (raises StreamlitAPIException -- verify
        # against the installed Streamlit version), and an exception raised without a
        # message stringifies to ''. Fall back to the class name so the handler itself
        # cannot fail.
        st.toast(str(e) or type(e).__name__)
        return ''
def clear_chat() -> None:
    """Remove the chat history from session state.

    The seeding code at the top of the script runs on the next rerun and
    recreates it with just the welcome message.
    """
    del st.session_state['prompt_history']
def set_workflow(flow: str):
    """Store the selected tool name (``'ask'``, ``'explain'`` or ``'format'``)."""
    st.session_state.workflow = flow
def reformat():
    """Render the "Format" tool: a code box, a language picker and a Format button.

    The chosen language is forwarded to ``utility_prompt`` as ``language``:
    - "Auto-Detect" sends no language.
    - "Other" sends the text typed in the "Other" box, or no language when that
      box is blank.
    - Any other entry sends its label.

    The model's markdown reply is rendered below the button.
    """
    source = st.text_area("Enter code to re-format", height=300)
    format_menu = st.columns(2)
    formatting_labels = ["Auto-Detect", "Python", "SQL", "Snowflake SQL", "JSON", "YAML", "Other"]
    with format_menu[0]:
        lang = st.selectbox(
            "Language",
            options=list(range(len(formatting_labels))),
            format_func=lambda x: formatting_labels[x],
        )
    label = formatting_labels[lang]
    # Free-text language box, editable only when "Other" is selected. Its key
    # depends on the selection, so (presumably intentionally) a fresh, empty
    # widget is used whenever the language changes.
    other = format_menu[1].text_input("Other", disabled=label != "Other", key=f"fmt_{lang}")

    if label == "Auto-Detect":
        language = ""
    elif label == "Other":
        language = other.strip()
    else:
        language = label
    # Omit the kwarg entirely so compose_prompt adds no language instruction.
    lang_options = {"language": language} if language else {}

    if st.button("Format", use_container_width=True, type="primary", disabled=not source):
        with st.spinner("Formatting Code"):
            st.markdown(utility_prompt(source, SystemMessages.FORMAT, **lang_options))
def explain():
    """Render the "Explain" tool: a code box (max 4000 chars) and an Explain button.

    On click, the code is sent with the EXPLAIN system instruction and the
    model's markdown reply is rendered.
    """
    code = st.text_area("Enter code to explain", height=300, max_chars=4000)
    clicked = st.button("Explain", use_container_width=True, type="primary", disabled=not code)
    if not clicked:
        return
    with st.spinner("Reading Code"):
        explanation = utility_prompt(code, SystemMessages.EXPLAIN)
        st.markdown(explanation)
def ask():
    """Render the chatbot: a toolbar, the chat transcript and the input box.

    Only the most recent exchange is shown unless "Show chat history" is on.
    "Clear" calls ``clear_chat``. Submitting input calls
    ``submit_prompt('prompt_input')``.
    """
    history_col, _spacer_col, clear_col = st.columns(3)
    show_history = history_col.toggle("Show chat history")
    clear_col.button("Clear", on_click=clear_chat, use_container_width=True)

    history = st.session_state['prompt_history']
    visible = history if show_history else history[-1:]
    for entry in visible:
        # The user's prompt first, then the model's reply; missing/empty parts are skipped.
        for role, field in (("user", "prompt"), ("ai", "response")):
            text = entry.get(field)
            if text:
                with st.chat_message(role):
                    st.write(text)

    st.chat_input(
        "Ask me anything",
        on_submit=submit_prompt,
        key='prompt_input',
        args=['prompt_input'],
    )
# Workflow name stored in session state -> function that renders that tool.
actions = {"explain": explain,
           "format": reformat,
           "ask": ask}

# The "$~...~$" prefixes are KaTeX spacing (presumably used to pad/centre text).
st.title(f"${'~'*8}$Cortex-Driven Toolbox")
headers = st.columns((.98, 2), gap="small")
headers[0].info(f"${'~'*15}$Generic")
headers[1].success(f"${'~'*35}$Code Tools")


def _tool_button(column, icon: str, title: str, flow: str) -> None:
    """Draw one dashboard tile in ``column``; its title is blue while ``flow`` is active."""
    color = 'blue' if st.session_state['workflow'] == flow else 'grey'
    # One label template for every tile; the CHATBOT tile previously used a stray
    # "\n\r" and different spacing from the other two.
    label = f"$~~~$ \n\n{icon}\n\n $~~~$ \n\n **:{color}[{title}]** \n\n$~~~$ \n\n"
    with column:
        st.button(label, use_container_width=True, on_click=set_workflow, args=[flow])


_TOOL_TILES = (
    ("🤖", "CHATBOT", "ask"),
    ("ℹ️", "EXPLAIN", "explain"),
    ("✏️", "FORMAT", "format"),
)
dashboard = st.columns(len(_TOOL_TILES))
for _column, (_icon, _title, _flow) in zip(dashboard, _TOOL_TILES):
    _tool_button(_column, _icon, _title, _flow)

# Render the selected tool; nothing is shown until a tile has been clicked.
fn = actions.get(st.session_state['workflow'])
if fn:
    fn()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment