Skip to content

Instantly share code, notes, and snippets.

@maybe-hello-world
Created May 20, 2024 21:02
Show Gist options
  • Save maybe-hello-world/37b484023d7b8274efc2cb6730ec51d7 to your computer and use it in GitHub Desktop.
Save maybe-hello-world/37b484023d7b8274efc2cb6730ec51d7 to your computer and use it in GitHub Desktop.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
from sklearn import tree
import graphviz
from trustee import ClassificationTrustee
from trustee.report.trust import TrustReport
# Output directory handed to trust_report.save().
OUTPUT_PATH = "out/"
# Location of the trust report file under OUTPUT_PATH. The trailing slash is
# stripped before joining so the result is "out/report/..." rather than the
# malformed "out//report/...".
# NOTE(review): REPORT_PATH is not referenced later in this script;
# trust_report.save() is passed OUTPUT_PATH instead.
REPORT_PATH = f"{OUTPUT_PATH.rstrip('/')}/report/trust_report.obj"
# Load scikit-learn's bundled Iris dataset and hold out 20% of it for testing.
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,  # fixed seed so the split is reproducible
)

# Model that is explained later: standardize the features, then classify
# them with logistic regression. The Pipeline fits both steps in order.
preprocessing_step = ("scaler", StandardScaler())
model_step = ("classifier", LogisticRegression())
pipeline = Pipeline(steps=[preprocessing_step, model_step])
pipeline.fit(X_train, y_train)

# Report the fitted pipeline's accuracy on the held-out split.
y_pred = pipeline.predict(X_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print(f'Accuracy: {accuracy:.2f}')
# Build a Trustee trust report for the fitted pipeline from the explicit
# train/test split above, then save it under OUTPUT_PATH.
# NOTE(review): every iteration count below is 5 and top_k is 1. These were
# presumably chosen to keep the demo fast; confirm against the trustee docs.
trust_report_options = {
    # Train/test split that the pipeline was fitted and evaluated on.
    "X_train": X_train,
    "y_train": y_train,
    "X_test": X_test,
    "y_test": y_test,
    # Report iteration / split settings.
    "max_iter": 5,
    "num_pruning_iter": 5,
    "train_size": 0.7,
    # trustee_*-prefixed settings (presumably passed on to the Trustee explainer).
    "trustee_num_iter": 5,
    "trustee_num_stability_iter": 5,
    "trustee_sample_size": 0.3,
    # Optional analyses, pruning size and logging.
    "analyze_branches": True,
    "analyze_stability": True,
    "top_k": 1,
    "verbose": True,
    # Dataset labels and task type.
    "class_names": iris.target_names,
    "feature_names": iris.feature_names,
    "is_classify": True,
}
trust_report = TrustReport(pipeline, **trust_report_options)
trust_report.save(OUTPUT_PATH)
# Fit a standalone Trustee explainer on the training split, with the fitted
# pipeline as the expert model being explained.
trustee = ClassificationTrustee(expert=pipeline)
trustee.fit(
    X_train,
    y_train,
    num_iter=5,
    num_stability_iter=5,
    samples_size=0.3,
    verbose=True,
)
# explain() yields the decision-tree explanation, its top-k pruned variant,
# and two scores that are printed below as (agreement, fidelity).
dt, pruned_dt, agreement, reward = trustee.explain()
print(f"Model explanation training (agreement, fidelity): ({agreement}, {reward})")
# tree_.node_count is the number of nodes in each fitted tree.
print(f"Model Explanation size: {dt.tree_.node_count}")
print(f"Top-k Pruned Model explanation size: {pruned_dt.tree_.node_count}")
# Render the full and the top-k pruned explanation trees with Graphviz.
# For each tree, graph.render(name) writes the DOT source to `name` and the
# rendered document to `name.pdf` (graphviz's default output format). Both
# paths are relative to the current working directory.
for explanation_tree, output_name in (
    (dt, "dt_explanation"),
    (pruned_dt, "pruned_dt_explanation"),
):
    dot_data = tree.export_graphviz(
        explanation_tree,
        class_names=iris.target_names,
        feature_names=iris.feature_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )
    graph = graphviz.Source(dot_data)
    graph.render(output_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment