@johnjosephhorton
Created May 3, 2024 12:53
Demonstrate using ggplot2 and SQL together
"""
This gives some examples of how to use SQL and ggplot2 together in the Results class.
Note that rather than calling LLMs, we create an agent with a direct-answering capability.
After we create the agent, we create a `Results` object.
"""
import random
from edsl import QuestionNumerical, QuestionFreeText
from edsl import Agent
from edsl.data import Cache
from textwrap import dedent
c = Cache()  # pass a fresh, empty cache so previously cached answers are not reused
def f(self, question, scenario):
    if question.question_name == "age":
        return random.randint(20, 90)
    elif question.question_name == "nationality":
        return random.choice(["USA", "UK", "France", "Germany"])
agent = Agent()
agent.add_direct_question_answering_method(f)
q_age = QuestionNumerical(
    question_text = "How old are you?",
    question_name = "age")
q_country = QuestionFreeText(
    question_text = "What country are you from?",
    question_name = "nationality")
results = q_age.add_question(q_country).by(agent).run(n = 100, cache = c)
results.sql('select age from self', shape = 'wide')
# This will return a plot of the age distribution, faceted by country
results.ggplot2("""ggplot(data = self) +
    facet_wrap(~nationality, ncol = 1) +
    geom_histogram(aes(x = age), bins = 10)""")
# Now, we write some SQL that takes the average age by group
age_by_country_sql = dedent("""\
select avg(age) as avg_age,
nationality
from self
group by nationality
""")
results.sql(age_by_country_sql)
## We can then use this SQL to pre-process the raw Results data before plotting
results.ggplot2("ggplot(data = self) + geom_bar(aes(x = nationality, y = avg_age), stat = 'identity')",
                sql = age_by_country_sql)
## If we want to debug the R code, we can set the debug parameter to True
results.ggplot2("ggplot(data = self) + geom_bar(aes(x = nationality, y = avg_age), stat = 'identity')",
                sql = age_by_country_sql,
                remove_prefix = True,
                shape = "wide",
                debug = True)
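## The group-by query used above can be sanity-checked without edsl or R.
## What follows is a minimal sketch, not part of the original gist: it runs the
## same SQL against an in-memory SQLite table named `self`, using only the
## standard library. The sample rows and the helper name
## `average_age_by_nationality` are hypothetical, and the sketch assumes only
## standard SQL semantics; it does not rely on how `Results.sql` is implemented.

```python
import sqlite3

# Hypothetical sample rows (age, nationality); illustrative only, not real output.
sample_rows = [
    (30, "USA"),
    (50, "USA"),
    (40, "UK"),
    (20, "France"),
    (70, "France"),
]

def average_age_by_nationality(rows):
    """Load rows into an in-memory table called `self` and average age per group."""
    conn = sqlite3.connect(":memory:")
    try:
        # The table name is quoted so it is always treated as an identifier.
        conn.execute('create table "self" (age integer, nationality text)')
        conn.executemany('insert into "self" values (?, ?)', rows)
        cursor = conn.execute(
            'select avg(age) as avg_age, nationality '
            'from "self" group by nationality order by nationality'
        )
        # Each result row is (avg_age, nationality); key the dict by nationality.
        return {nationality: avg_age for avg_age, nationality in cursor.fetchall()}
    finally:
        conn.close()

averages = average_age_by_nationality(sample_rows)
print(averages)
```

## Each key in `averages` corresponds to one bar in the bar chart above, with the
## averaged age as the bar height.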