Example Airflow Workflow
import io
from datetime import datetime

import pandas as pd
from sklearn.linear_model import LinearRegression

from airflow.models import DAG
from airflow.operators.python import PythonVirtualenvOperator
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

# Assume that you have already created connections for Snowflake and AWS S3
# following the instructions at: https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html.
# This is part of Airflow's API for credential management.
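# A minimal sketch of how these connections might be registered from the
# Airflow CLI (all values below are placeholders; see the docs linked above
# for the full set of Snowflake/AWS connection fields):
#
#   airflow connections add snowflake_default_conn --conn-type snowflake \
#       --conn-login <user> --conn-password <password>
#   airflow connections add s3_default_conn --conn-type aws \
#       --conn-extra '{"aws_access_key_id": "<key>", "aws_secret_access_key": "<secret>"}'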
SNOWFLAKE_CONN_ID = "snowflake_default_conn"
S3_CONN_ID = "s3_default_conn"
S3_BUCKET_NAME = "prod_bucket"
S3_DATA_KEY = "us_customers"
S3_PREDICTIONS_KEY = "predictions"
SNOWFLAKE_PREDICTIONS_TABLE = "predictions"

# Assume the model has already been trained on some sample data.
# `data.csv` is expected to contain two numeric columns: a feature and a target.
data = pd.read_csv('data.csv')
X = data.iloc[:, 0].values.reshape(-1, 1)
Y = data.iloc[:, 1].values.reshape(-1, 1)
linear_model = LinearRegression()
linear_model.fit(X, Y)

dag = DAG(
    dag_id='prediction_workflow',
    default_args={
        'retries': 0,
    },
    start_date=datetime(2023, 1, 1, 1),
    schedule_interval='0 8 * * *',  # run every day at 08:00
)

def extract():
    '''
    Performs the entire data extraction stage. This includes
    reading the data from the data warehouse and then writing
    it to a location that can be accessed by subsequent tasks.
    '''
    sf_hook = SnowflakeHook(snowflake_conn_id=SNOWFLAKE_CONN_ID)
    with sf_hook.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM customers WHERE location = 'US';")
            # res is a list of tuples, one per row
            res = cur.fetchall()
            # Determine the column names from the cursor metadata
            col_names = [desc[0] for desc in cur.description]
    df = pd.DataFrame(res, columns=col_names)

    # Save the DataFrame to AWS S3, so it can be accessed by the next operator
    s3_hook = S3Hook(aws_conn_id=S3_CONN_ID)
    buffer = io.BytesIO()
    df.to_parquet(buffer, index=False)
    buffer.seek(0)  # Reset stream position of buffer
    s3_hook.load_file_obj(buffer, S3_DATA_KEY, S3_BUCKET_NAME, replace=True)

def predict():
    '''
    Performs linear regression on the input data.
    Predictions are added as a new column named `score` in the input DataFrame.
    In addition, we need to read the input data from the
    location it was written to by the previous extract task.
    The output then needs to be written to a location from where it can
    be accessed by subsequent tasks.
    '''
    # Load the DataFrame from AWS S3
    s3_hook = S3Hook(aws_conn_id=S3_CONN_ID)
    data_bytes = s3_hook.get_key(S3_DATA_KEY, S3_BUCKET_NAME).get()['Body'].read()
    data_file = io.BytesIO(data_bytes)
    df = pd.read_parquet(data_file)

    # LinearRegression has no predict_proba; use predict() to score each row.
    # (This assumes the extracted columns match the features the model was trained on.)
    df['score'] = linear_model.predict(df).ravel()

    # Save the predictions to AWS S3, so they can be accessed by the next operator
    buffer = io.BytesIO()
    df.to_parquet(buffer, index=False)
    buffer.seek(0)  # Reset stream position of buffer
    s3_hook.load_file_obj(buffer, S3_PREDICTIONS_KEY, S3_BUCKET_NAME, replace=True)

def save():
    '''
    Performs the data saving stage. This involves first reading the output of the
    previous stage and then writing it to the data warehouse.
    '''
    # Load the predictions DataFrame from AWS S3
    s3_hook = S3Hook(aws_conn_id=S3_CONN_ID)
    data_bytes = s3_hook.get_key(S3_PREDICTIONS_KEY, S3_BUCKET_NAME).get()['Body'].read()
    data_file = io.BytesIO(data_bytes)
    df = pd.read_parquet(data_file)

    # Save the predictions into a Snowflake table.
    # pandas.DataFrame.to_sql expects a SQLAlchemy engine rather than a raw connection.
    sf_hook = SnowflakeHook(snowflake_conn_id=SNOWFLAKE_CONN_ID)
    df.to_sql(
        SNOWFLAKE_PREDICTIONS_TABLE,
        con=sf_hook.get_sqlalchemy_engine(),
        index=False,
        if_exists='replace',  # overwrite the table on each scheduled run
    )

# Each task_id must be a unique task identifier within this DAG
extract_stage = PythonVirtualenvOperator(
    task_id='extract',
    python_callable=extract,
    requirements=["pandas", "snowflake-sqlalchemy", "SQLAlchemy"],
    dag=dag,
)

prediction_stage = PythonVirtualenvOperator(
    task_id='prediction',
    python_callable=predict,
    requirements=["pandas", "scikit-learn"],
    dag=dag,
)

save_stage = PythonVirtualenvOperator(
    task_id='save',
    python_callable=save,
    requirements=["pandas", "snowflake-sqlalchemy", "SQLAlchemy"],
    dag=dag,
)

# Run the three stages in sequence: extract -> predict -> save
extract_stage >> prediction_stage >> save_stage
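
# A rough sketch of how this DAG could be exercised locally, assuming this file
# lives in the configured DAGs folder and the Airflow CLI is available:
#
#   airflow dags test prediction_workflow 2023-01-01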