Skip to content

Instantly share code, notes, and snippets.

@nownabe
Created May 15, 2023 07:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nownabe/97e4bd8ea600d71bbfac12daf0c4110e to your computer and use it in GitHub Desktop.
Save nownabe/97e4bd8ea600d71bbfac12daf0c4110e to your computer and use it in GitHub Desktop.
SQL Generation by PaLM API
import io
import sys
from google.cloud import bigquery
from vertexai.preview.language_models import TextGenerationModel
# Shared BigQuery client: used by render_prompt() to fetch table metadata and
# serialize the schema, and by the script below to run the generated SQL.
client = bigquery.Client()
def render_prompt(full_table_id: str, user_prompt: str) -> str:
    """Build the LLM prompt asking for a BigQuery SQL query over one table.

    The table's metadata is fetched through the module-level ``client`` and
    embedded in the prompt: its name, its schema as JSON and, when the table
    requires a partition filter, an instruction to filter on the partitioning
    column.

    Args:
        full_table_id: Identifier of the table, passed to ``client.get_table``.
        user_prompt: The question the generated SQL should answer.

    Returns:
        The complete prompt text to send to the text generation model.
    """
    table = client.get_table(full_table_id)

    restriction = "None"
    if (partitioning := table.time_partitioning) and partitioning.require_partition_filter:
        # Ingestion-time partitioned tables have no partitioning column
        # (`field` is None); they are filtered through the _PARTITIONTIME
        # pseudo-column instead.
        partition_field = partitioning.field or "_PARTITIONTIME"
        restriction = f"You MUST filter `{partition_field}` field in a `where` clause."

    with io.StringIO("") as buf:
        client.schema_to_json(table.schema, buf)
        schema = buf.getvalue()

    # Table.full_table_id uses the legacy "project:dataset.table" form, which
    # is not a valid standard-SQL table path; build the dotted form so the
    # model can copy the name straight into a FROM clause.
    table_path = f"{table.project}.{table.dataset_id}.{table.table_id}"

    return f"""You are an experienced data analyst. Write a BigQuery SQL to answer the user's prompt based on the following context.
---- Context ----
Format: Plain SQL only, no Markdown
Table: `{table_path}`
Restriction: {restriction}
Schema as JSON:
{schema}
----
User's prompt: {user_prompt}"""
# Script entry: generate SQL for the given table and question with the PaLM
# text model, run it on BigQuery, and print the rows tab-separated.
# Usage: <script> <full_table_id> <user_prompt>
if len(sys.argv) < 3:
    # Fail with a usage message instead of a bare IndexError traceback.
    sys.exit(f"usage: {sys.argv[0]} <full_table_id> <user_prompt>")

full_table_id = sys.argv[1]
user_prompt = sys.argv[2]

prompt = render_prompt(full_table_id, user_prompt)
print(f"==== Prompt ====\n{prompt}")

model = TextGenerationModel.from_pretrained("text-bison@001")
result = model.predict(prompt)
print(f"\n==== Response ====\n{result.text}")

# NOTE(security): the model's output is executed verbatim. It is not checked
# to be a read-only SELECT, so only run this with credentials whose
# permissions you are comfortable exposing to generated SQL.
rows = client.query(result.text).result()

print("\n==== Result ====\n")
for row in rows:
    print("\t".join(map(str, row.values())))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment