Created
May 3, 2024 12:53
-
-
Save johnjosephhorton/a29f43ccf226379aaae5d6447b67c1cf to your computer and use it in GitHub Desktop.
Demonstrate using ggplot2 and SQL together
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This gives some examples of how to use SQL and ggplot2 together in the Results class.
Note that rather than calling LLMs, we create an agent with a direct-answering capability.
After we create the agent, we create a `Results` object.
""" | |
import random | |
from edsl import QuestionNumerical, QuestionFreeText | |
from edsl import Agent | |
from edsl.data import Cache | |
from textwrap import dedent | |
# Cache object handed to `run(..., cache = c)` further down. Per the author's
# note, passing one explicitly keeps the run from using caching (presumably
# the default stored cache -- TODO confirm against the edsl documentation).
c = Cache()
def f(self, question, scenario):
    """Answer a question directly with random data instead of calling an LLM.

    Written to be attached to an agent via
    ``add_direct_question_answering_method``, hence the ``self`` parameter.

    Args:
        self: The object the method is bound to; not used here.
        question: Object exposing a ``question_name`` attribute.
        scenario: Not used here.

    Returns:
        A random integer in [20, 90] when the question is named ``"age"``,
        one of "USA", "UK", "France" or "Germany" when it is named
        ``"nationality"``, and ``None`` for any other question name.
    """
    name = question.question_name
    if name == "age":
        return random.randint(20, 90)
    if name == "nationality":
        return random.choice(["USA", "UK", "France", "Germany"])
    # Unrecognised question: keep the original fall-through result, explicitly.
    return None
# Create an agent and register `f` as its direct question-answering method,
# so answers come from `f` rather than from an LLM call.
agent = Agent()
agent.add_direct_question_answering_method(f)

# The question names below are the ones `f` checks for.
q_age = QuestionNumerical(
    question_text="How old are you?",
    question_name="age",
)
q_country = QuestionFreeText(
    question_text="What country are you from?",
    question_name="nationality",
)

# Combine both questions, assign them to the agent and run with n = 100,
# passing the cache object `c` created earlier in the file.
results = q_age.add_question(q_country).by(agent).run(n=100, cache=c)
# Query the results with SQL; the query refers to the results table as `self`.
results.sql("select age from self", shape="wide")

# This will return a plot of the age distribution, faceted by country.
# The R code is assembled line by line so that no stray characters end up
# inside the string that is handed to ggplot2.
age_histogram_r = (
    "ggplot(data = self) +\n"
    "facet_wrap(~nationality, ncol = 1) +\n"
    "geom_histogram(aes(x = age), bins = 10)"
)
results.ggplot2(age_histogram_r)
# Now, we write some SQL that takes the average by group.
# `dedent` strips the common indentation, so the query text itself starts
# at column 0 on every line.
age_by_country_sql = dedent("""\
    select avg(age) as avg_age,
    nationality
    from self
    group by nationality
    """)
results.sql(age_by_country_sql)

# Both plots below use exactly the same ggplot2 code, so it is defined once.
avg_age_bar_r = "ggplot(data = self) + geom_bar(aes(x = nationality, y = avg_age), stat = 'identity')"

# We can then use this SQL to pre-process the raw Results data before plotting.
results.ggplot2(avg_age_bar_r, sql=age_by_country_sql)

# If we want to debug the R code, we can set the debug parameter to True.
results.ggplot2(
    avg_age_bar_r,
    sql=age_by_country_sql,
    remove_prefix=True,
    shape="wide",
    debug=True,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment