Iris Classification
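"""
Iris classification DAG: train a DecisionTreeClassifier on the Iris dataset,
export it with joblib, record the model path in MySQL, then load the latest
model back for sample predictions and remove the file afterwards.
"""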
from datetime import datetime, timedelta
import os
import random

import joblib
import mysql.connector
import numpy as np
from airflow import DAG
from airflow.operators.python import PythonOperator
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Set the Airflow base directory
DAGS_DIR = os.environ.get('AIRFLOW__CORE__DAGS_FOLDER', '/opt/airflow/dags')
def train_and_export_model():
    # Load the dataset
    iris = load_iris()
    X = iris.data
    y = iris.target

    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Train the model
    model = DecisionTreeClassifier()
    print(f"model: {model}")
    model.fit(X_train, y_train)

    # Generate a unique folder path with a timestamp
    unique_folder = datetime.now().strftime('%Y%m%d%H%M%S')
    folder_path = os.path.join(DAGS_DIR, unique_folder)

    # Create the folder if it doesn't exist
    os.makedirs(folder_path, exist_ok=True)

    # Generate a unique path for the .pkl file
    model_filename = 'iris_model.pkl'
    model_path = os.path.join(folder_path, model_filename)

    # Export the model
    joblib.dump(model, model_path)
    # Store the model path in MySQL
    connection = mysql.connector.connect(
        host='mysql.default.svc.cluster.local',
        user='root',
        password='password',
        database='my_database'
    )
    cursor = connection.cursor()

    # Create the table if it doesn't exist yet (this must run before any
    # ALTER TABLE, which fails on a missing table)
    create_table_query = '''
        CREATE TABLE IF NOT EXISTS models (
            id INT AUTO_INCREMENT PRIMARY KEY,
            model_path VARCHAR(255)
        )
    '''
    cursor.execute(create_table_query)

    try:
        # Add the model_path column to tables created with an older schema
        alter_table_query = '''
            ALTER TABLE models
            ADD COLUMN model_path VARCHAR(255)
        '''
        cursor.execute(alter_table_query)
    except mysql.connector.Error as err:
        # Error 1060: the column already exists, which is the expected case
        if err.errno == 1060:
            print("Column 'model_path' already exists in table 'models'")
        else:
            print("Error:", err)

    # Insert the model path into the table
    insert_query = '''
        INSERT INTO models (model_path) VALUES (%s)
    '''
    cursor.execute(insert_query, (model_path,))
    connection.commit()

    # Close the connection
    cursor.close()
    connection.close()
def load_model_from_database():
    connection = mysql.connector.connect(
        host='mysql.default.svc.cluster.local',
        user='root',
        password='password',
        database='my_database'
    )
    cursor = connection.cursor()

    # Fetch the most recently stored model path
    select_query = '''
        SELECT model_path FROM models ORDER BY id DESC LIMIT 1
    '''
    cursor.execute(select_query)
    result = cursor.fetchone()

    if result is not None:
        # The path was stored as an absolute path, so it can be used directly
        model_path = result[0]

        # Load the model from the stored path
        loaded_model = joblib.load(model_path)

        # Use the loaded model for inference or further processing
        species_names = ['setosa', 'versicolor', 'virginica']

        # Generate a random input for each species; note that the prediction
        # is driven by the random features, not by the species name itself
        for species in species_names:
            sepal_length = random.uniform(4.0, 8.0)
            sepal_width = random.uniform(2.0, 4.5)
            petal_length = random.uniform(1.0, 7.0)
            petal_width = random.uniform(0.1, 2.5)
            X_new = np.array([[sepal_length, sepal_width, petal_length, petal_width]])  # Input for prediction
            prediction = loaded_model.predict(X_new)
            print("Species:", species)
            print("Prediction:", prediction)

            # Generate a link to view images of the predicted species
            species_name = species.lower()
            image_link = f"https://en.wikipedia.org/wiki/Iris_{species_name}"
            print("Image Link:", image_link)
            print()

        # Remove the model file
        os.remove(model_path)

    cursor.close()
    connection.close()
default_args = {
    'start_date': datetime(2023, 5, 1),
    'retries': 3,
    'retry_delay': timedelta(minutes=5)
}

dag = DAG(
    'iris_classification_try_me_again',
    default_args=default_args,
    description="An ETL with lineage",
    tags=['workshop', 'ETL'],
    schedule_interval=None
)

with dag:
    train_export_operator = PythonOperator(
        task_id='train_and_export',
        python_callable=train_and_export_model,
        outlets={
            "tables": ["mysql.my_database.my_database.models"]
        },
        inlets={
            "tables": ["mysql.my_database.my_database.models"]
        }
    )

    load_model_operator = PythonOperator(
        task_id='load_model',
        python_callable=load_model_from_database,
    )

    # Define the task dependency: train and export before loading
    train_export_operator >> load_model_operator
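
Because the DAG is defined with schedule_interval=None, it only runs when triggered manually. Assuming a standard Airflow 2.x deployment where this file sits in the DAGs folder, one way to start a run is from the Airflow CLI, using the DAG id from the definition above:

    airflow dags trigger iris_classification_try_me_again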