Example Airflow Workflow
import io
from datetime import datetime

import pandas as pd
from sklearn.linear_model import LinearRegression

from airflow.models import DAG
from airflow.operators.python import PythonVirtualenvOperator
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

# Assume that you already created connections for Snowflake and AWS S3
# following the instructions at: https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html.
# This is part of Airflow's API for credential management.
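# One way to create them (a sketch; the account names and credentials below are
# placeholders, not real values) is through Airflow's AIRFLOW_CONN_<CONN_ID>
# environment variables, each of which holds a connection URI, e.g.:
#
#   export AIRFLOW_CONN_SNOWFLAKE_DEFAULT_CONN='snowflake://user:password@account'
#   export AIRFLOW_CONN_S3_DEFAULT_CONN='aws://access_key:secret_key@'
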
SNOWFLAKE_CONN_ID = "snowflake_default_conn"
S3_CONN_ID = "s3_default_conn"
S3_BUCKET_NAME = "prod-bucket"  # S3 bucket names may not contain underscores
S3_DATA_KEY = "us_customers"
S3_PREDICTIONS_KEY = "predictions"
SNOWFLAKE_PREDICTIONS_TABLE = "predictions"

# Assume the model is trained with some sample data
data = pd.read_csv('data.csv')
X = data.iloc[:, 0].values.reshape(-1, 1)  # single feature column
Y = data.iloc[:, 1].values  # 1-D target, so predictions come back 1-D
linear_model = LinearRegression()
linear_model.fit(X, Y)
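# In a real pipeline the model would more likely be loaded from a serialized
# artifact rather than retrained every time the scheduler parses this file.
# A sketch, assuming a pre-trained model saved at the hypothetical path
# 'model.joblib':
#
#   import joblib
#   linear_model = joblib.load('model.joblib')
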
dag = DAG(
    dag_id='prediction_workflow',
    default_args={
        'retries': 0,
    },
    start_date=datetime(2023, 1, 1, 1),
    schedule_interval='0 8 * * *',
)
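# The cron expression '0 8 * * *' runs this DAG once a day at 08:00 (UTC by
# default). Note that with a 2023 start_date and catchup left at its default,
# Airflow will also backfill a run for each day since the start date.
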
def extract():
    '''
    Performs the entire data extraction stage. This includes
    reading the data from the data warehouse and then writing
    it to a location that can be accessed by subsequent tasks.
    '''
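    # Tasks run in separate processes (here, separate virtualenvs), so data is
    # handed off between stages through S3 rather than through memory.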
    sf_hook = SnowflakeHook(snowflake_conn_id=SNOWFLAKE_CONN_ID)
    with sf_hook.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM customers WHERE location = 'US';")
            # res is a list of tuples, one per row
            res = cur.fetchall()
            # Determine the column names from the cursor metadata
            col_names = [desc[0] for desc in cur.description]
    df = pd.DataFrame(res, columns=col_names)

    # Save the DataFrame to AWS S3, so it can be accessed by the next operator
    s3_hook = S3Hook(aws_conn_id=S3_CONN_ID)
    buffer = io.BytesIO()
    df.to_parquet(buffer, index=False)
    buffer.seek(0)  # Reset stream position of buffer
    s3_hook.load_file_obj(buffer, S3_DATA_KEY, S3_BUCKET_NAME, replace=True)

def predict():
    '''
    Performs linear regression on the input data.
    Predictions are added as a new column named `score` in the input DataFrame.
    In addition, we need to read the input data from the
    location it was written to by the previous extract task.
    The output then needs to be written to a location from where it can
    be accessed by subsequent tasks.
    '''
    # Load the DataFrame from AWS S3. get_key returns a boto3 S3.Object, whose
    # body must be read into bytes.
    s3_hook = S3Hook(aws_conn_id=S3_CONN_ID)
    data_obj = s3_hook.get_key(S3_DATA_KEY, S3_BUCKET_NAME)
    data_file = io.BytesIO(data_obj.get()['Body'].read())
    df = pd.read_parquet(data_file)

    # Score with the regression model. This assumes the first column of the
    # extracted data is the model's single input feature.
    df['score'] = linear_model.predict(df.iloc[:, 0].values.reshape(-1, 1))

    # Save the predictions to AWS S3, so they can be accessed by the next operator
    buffer = io.BytesIO()
    df.to_parquet(buffer, index=False)
    buffer.seek(0)  # Reset stream position of buffer
    s3_hook.load_file_obj(buffer, S3_PREDICTIONS_KEY, S3_BUCKET_NAME, replace=True)

def save():
    '''
    Performs the data saving stage. This involves first reading the output of the
    previous stage and then writing it to the data warehouse.
    '''
    # Load the predictions DataFrame from AWS S3
    s3_hook = S3Hook(aws_conn_id=S3_CONN_ID)
    data_obj = s3_hook.get_key(S3_PREDICTIONS_KEY, S3_BUCKET_NAME)
    data_file = io.BytesIO(data_obj.get()['Body'].read())
    df = pd.read_parquet(data_file)
    # Save the predictions into a Snowflake table
    sf_hook = SnowflakeHook(snowflake_conn_id=SNOWFLAKE_CONN_ID)
    df.to_sql(
        SNOWFLAKE_PREDICTIONS_TABLE,
        con=sf_hook.get_sqlalchemy_engine(),  # to_sql needs a SQLAlchemy engine, not a raw connection
        if_exists='replace',  # overwrite the table on each daily run
        index=False,
    )

# Each operator needs a task_id that is unique within this DAG. Note that
# PythonVirtualenvOperator runs each callable in a fresh interpreter, so in a
# real deployment each function would import its dependencies and load the
# model inside the function body.
extract_stage = PythonVirtualenvOperator(
    task_id='extract',
    python_callable=extract,
    requirements=["pandas", "pyarrow", "snowflake-sqlalchemy", "SQLAlchemy"],
    dag=dag,
)

prediction_stage = PythonVirtualenvOperator(
    task_id='predict',
    python_callable=predict,
    requirements=["pandas", "pyarrow", "scikit-learn"],
    dag=dag,
)

save_stage = PythonVirtualenvOperator(
    task_id='save',
    python_callable=save,
    requirements=["pandas", "pyarrow", "snowflake-sqlalchemy", "SQLAlchemy"],
    dag=dag,
)

extract_stage >> prediction_stage >> save_stage
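# The >> chaining above sets the execution order: extract runs first, then
# predict, then save.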