Skip to content

Instantly share code, notes, and snippets.

@meyer1994
Last active January 4, 2024 20:59
Show Gist options
  • Save meyer1994/0d721985027cbb95b77c55e8b59520aa to your computer and use it in GitHub Desktop.
Save meyer1994/0d721985027cbb95b77c55e8b59520aa to your computer and use it in GitHub Desktop.
Playing around with langchain
import os

from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.schema import StrOutputParser
from langchain_community.chat_models import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough
# Take the OpenAI key from the OPENAI_API_KEY environment variable so a real
# secret never has to be pasted into this file. The original placeholder is
# kept as the fallback, so behaviour is unchanged when the variable is unset.
llm = ChatOpenAI(api_key=os.environ.get('OPENAI_API_KEY', 'MY_COOL_API_KEY'))
# System message text: sets the model's role and restricts the reply to the
# SQL query only. The text begins and ends with a newline.
PROMPT_SYSTEM = (
    '\n'
    'You are an experienced SQL developer and database specialist.\n'
    'You will receive a schema and a question.\n'
    'You will convert questions to SQL queries based on the schema.\n'
    'You will only output the SQL query.\n'
)
# User message template. It has two placeholders: {schema}, placed inside a
# fenced sql block, and {question}, placed on a quoted line.
PROMPT_USER = (
    '\n'
    '# Table Schema\n'
    '```sql\n'
    '{schema}\n'
    '```\n'
    '# Question\n'
    '> {question}\n'
)
def get_schema(_) -> str:
    """Return the SQL DDL describing the database schema.

    The single positional argument is accepted and ignored; it lets this
    function be passed as the ``schema=`` callable of
    ``RunnablePassthrough.assign``.

    Returns:
        The ``CREATE TABLE`` statement for the ``users`` table, with a
        leading and a trailing newline.
    """
    # The literal's lines stay at column 0 so the returned text is unchanged.
    return '''
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT,
birthday DATETIME
)
'''
# Chat prompt made of two templated messages, in order: the system
# instructions, then the user message carrying the schema and the question.
system_message = SystemMessagePromptTemplate.from_template(PROMPT_SYSTEM)
human_message = HumanMessagePromptTemplate.from_template(PROMPT_USER)
prompt = ChatPromptTemplate.from_messages([system_message, human_message])
# Pipeline stages, in execution order:
#   1. add a `schema` key (computed by get_schema) to the input mapping,
#   2. render the chat prompt,
#   3. call the chat model,
#   4. parse the model output with StrOutputParser.
add_schema = RunnablePassthrough.assign(schema=get_schema)
to_text = StrOutputParser()
chain = add_schema | prompt | llm | to_text
# Run the chain once with a natural-language question and print the result.
question = 'Show me the month with most users born in'
res = chain.invoke({'question': question})
print(res)

# Example output:
# SELECT strftime('%m', birthday) AS birth_month, COUNT(*) AS num_users
# FROM users
# GROUP BY birth_month
# ORDER BY num_users DESC
# LIMIT 1;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment