Last active
April 12, 2021 00:44
-
-
Save AviSoori1x/6d6634582d1c24cec249da9ed40a8cbc to your computer and use it in GitHub Desktop.
ML Explainability App
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
from PIL import Image | |
from swat import CAS, options | |
import json | |
import time | |
import math | |
import SessionState | |
import altair as alt | |
import random | |
import io | |
import plotly.express as px | |
import base64 | |
# Session-scoped slots used further down in this script:
#   swat_sess - the CAS connection stored by the Login handler
#   samplepd  - the one-row DataFrame built from the input widgets
#   scored    - that row with the predicted label added
#   numtopics - not referenced in the visible part of the file
state = SessionState.get(
    swat_sess=None,
    samplepd=None,
    scored=None,
    numtopics=None,
)
def get_table_download_link(df):
    """Build an HTML anchor that lets the browser download *df* as CSV.

    The frame is serialised without its index, base64-encoded and embedded
    in a ``data:`` URI; the suggested file name is ``sentiment.csv``.

    in: dataframe
    out: href string
    """
    csv_bytes = df.to_csv(index=False).encode()
    encoded = base64.b64encode(csv_bytes).decode()
    template = (
        '<a href="data:file/csv;base64,{}" download="sentiment.csv">'
        'Download Sentiment scores file</a>'
    )
    return template.format(encoded)
def get_topic_download_link(df):
    """Build an HTML anchor that lets the browser download topic scores as CSV.

    The frame is serialised without its index, base64-encoded and embedded
    in a ``data:`` URI.

    in: dataframe
    out: href string
    """
    csv = df.to_csv(index=False)
    # base64 works on bytes, while the URI needs text: str -> bytes -> str.
    b64 = base64.b64encode(csv.encode()).decode()
    # The suggested file name used to be "sentiment.csv" (copied from the
    # sentiment link), so a topic download was saved under the wrong name.
    return f'<a href="data:file/csv;base64,{b64}" download="topics.csv">Download Topic scores file</a>'
# Logo passed to st.sidebar.image() below; the path is relative to the
# directory the app is started from.
image = Image.open(fp='images/sasLogo.png')
def dicttopd(datadictionary):
    """Turn a flat ``{column: scalar}`` mapping into a one-row DataFrame.

    Each value is wrapped in a single-element list so pandas treats it as a
    column of length one.  The wrapping is done on a new mapping: the
    previous version rewrote ``datadictionary`` in place, so the caller's
    dict was altered and a second call produced doubly nested values.

    in: dict of column name -> value
    out: pandas DataFrame with exactly one row (empty frame for an empty dict)
    """
    wrapped = {key: [value] for key, value in datadictionary.items()}
    return pd.DataFrame.from_dict(wrapped)
# Stylesheet injected into the page as raw HTML: blue sidebar with white
# text, black widget labels in the main area, white footer text, and no
# header decoration image.
_custom_css = """
<style>
.sidebar .sidebar-content {
background-image: linear-gradient(#2e7bcf,#2e7bcf);
color: white;
}
.sidebar-content .Widget>label {
color: white;
}
.Widget>label {
color: black;
}
[class^="st-b"] {
color: black;
}
.st-bb {
background-color: transparent;
}
.st-at {
background-color: white;
}
.reportview-container .main footer, .reportview-container .main footer a {
color: white;
}
header .decoration {
background-image: none;
}
</style>
"""
st.markdown(_custom_css, unsafe_allow_html=True)
# Sidebar header: logo, app title, then the three connection fields that
# are handed to logon() when the Login button is pressed.
st.sidebar.image(image)
st.sidebar.write("\n# Explainable AI App\n")
hostname = st.sidebar.text_input('Please Enter host URI')
username = st.sidebar.text_input('Please Enter host Username')
password = st.sidebar.text_input(
    'Please Enter host Password',
    type="password",
)
def logon(hostname, username, password):
    """Open a CAS session and load the action sets used by this app.

    in: hostname, username, password - connection details for the CAS server
    out: the connected ``CAS`` session object

    Any exception raised by ``CAS(...)`` or by loading an action set is
    passed through to the caller.
    """
    s = CAS(hostname=hostname, protocol='cas',
            username=username, password=password)
    # 'table' used to be loaded twice; each action set is now requested once.
    action_sets = (
        'autotune',
        'aStore',
        'table',
        'decisionTree',
        'explainModel',
        'sentimentAnalysis',
        'textMining',
    )
    for action_set in action_sets:
        s.loadactionset(action_set)
    return s
# Module-level CAS session handle; stays None until a login succeeds.
s = None
if st.sidebar.button('Login'):
    try:
        # logon() takes (hostname, username, password).  The host was
        # previously left out, so the call always raised TypeError, which
        # the bare except swallowed, and no login could ever succeed.
        s = logon(hostname, username, password)
    except Exception:
        st.sidebar.write('Please Enter a valid user name and password')
    else:
        # Keep the connection in session state so later reruns can reuse it.
        state.swat_sess = s
# Page heading: a level-1 markdown title built from the sidebar text field.
demo_title = st.sidebar.text_input('Please Enter Demo Title')
heading = f"\n# {demo_title}\n"

# Scenario picker; only 'Home Loan Default' is handled in the visible part
# of this file.
option = st.sidebar.selectbox(
    'Select the scenario you wish to demonstrate',
    (
        'Home Loan Default',
        'Concrete Strength Predictor',
        'Telecom Customer Churn',
        'Financial Fraud Detection',
        'Retail Price Forecast',
        'Predict Early Stage Diabetes Risk',
        'Bank Customer Churn',
        'Heart Disease Prediction',
        'NLP : Sentiment',
        'NLP : Sentiment and SVD Topics',
    ),
)

# The return values of these two buttons are ignored, so pressing them
# triggers no action here.
st.sidebar.button('Download App Code')
st.sidebar.button('Download App Docker Image')

# Every scenario except fraud detection shows the heading here.
if not option == 'Financial Fraud Detection':
    st.write(heading)
if option == 'Home Loan Default':
    st.write('Please choose variable inputs: ')

    def user_input_features():
        """Render the loan input widgets and return the selections.

        out: one-row DataFrame (index 0) with columns LOAN, MORTDUE, VALUE,
        YOJ, DEROG, DELINQ, CLAGE, NINQ, CLNO, DEBTINC, JOB and REASON.
        """
        # Dict values are evaluated top to bottom, so the widgets appear on
        # the page in exactly this order.
        data = {
            'LOAN': st.slider('Loan Amount', 1000, 90000, 20000),
            'MORTDUE': st.slider('Mortgage Due', 2000, 400000, 10000),
            'VALUE': st.slider('Value of Property', 6000, 860000, 75000),
            'YOJ': st.slider('Years on the Job', 0, 45, 20),
            'DEROG': st.slider('Number of derogatory remarks', 0, 10, 4),
            'DELINQ': st.slider('Number of delinquencies', 0, 16, 8),
            'CLAGE': st.slider('Age of credit line in months', 0, 240, 12),
            'NINQ': st.slider('Number of inquiries', 0, 20, 10),
            'CLNO': st.slider('Number of credit lines', 0, 75, 35),
            'DEBTINC': st.slider('Debt to Income ratio', 0.5, 205.0, 80.0),
            'JOB': st.selectbox('Job/Occupation',
                                ('Mgr', 'Office', 'Other', 'ProfExe', 'Sales', 'Self')),
            'REASON': st.selectbox('Reason',
                                   ('DebtCon', 'HomeImp')),
        }
        return pd.DataFrame(data, index=[0])

    samplepd = user_input_features()
    state.samplepd = samplepd
    st.write('Please verify that the input data is correct: ')
    st.write(samplepd)

    # swat_sess is created with a default of None, so reading it does not
    # raise when nobody has logged in.  The old try/except therefore never
    # showed this hint; test the value instead.
    s = state.swat_sess
    if s is None:
        st.write('Please log in to the SAS Viya to use the AI Composer')

    @st.cache
    def score(samplepd):
        """Score the input row with the stored model and return the result.

        Uploads *samplepd* to public.realtime, scores it with the
        hmeqTestAstore analytic store into public.realscore, and returns
        that table as a DataFrame.  Uses the module-level session ``s``.
        """
        s.upload(samplepd, casout={'name': 'realtime', 'caslib': 'public', 'replace': True})
        s.aStore.score(rstore={"caslib": "public", "name": "hmeqTestAstore"},
                       table={"caslib": 'public', "name": 'realtime'},
                       out={"caslib": 'public', "name": 'realscore', 'replace': True})
        scoredData = s.CASTable(name='realscore', caslib='public')
        datasetDict = scoredData.to_dict()
        scores = pd.DataFrame(datasetDict, index=[0])
        return scores

    @st.cache
    def explainML(samplepd):
        """Return per-variable contribution estimates for the scored row.

        Uploads *samplepd* (which must already contain the BAD column) as
        the query table and runs the linearExplainer action with the
        KERNELSHAP preset against the hmeqTest data.

        out: (labels, estimates) - two parallel lists, with any entry for
        the target variable BAD removed.
        """
        s.upload(samplepd, casout={'name': 'realtime', 'caslib': 'public', 'replace': True})
        shapvals = s.linearExplainer(
            table={"name": 'hmeqTest', 'caslib': 'public'},
            query={"name": 'realtime', 'caslib': 'public'},
            modelTable={"name": "hmeqTestAstore", 'caslib': 'public'},
            modelTableType="ASTORE",
            predictedTarget='P_BAD1',
            seed=1234,
            preset="KERNELSHAP",
            inputs=['LOAN', 'MORTDUE', 'VALUE', 'YOJ', 'DEROG', 'DELINQ', 'CLAGE',
                    'NINQ', 'CLNO', 'DEBTINC', 'REASON', 'JOB', 'BAD'],
            nominals=['REASON', 'JOB', 'BAD']
        )
        # Rows 1..9 only; row 0 is skipped - presumably the intercept,
        # TODO confirm against the action's output.
        shap = shapvals['ParameterEstimates'][['Variable', 'Estimate']][1:10]
        target = 'BAD'
        labels_n = []
        data_n = []
        for label, estimate in zip(shap['Variable'], shap['Estimate']):
            if label != target:
                labels_n.append(label)
                data_n.append(estimate)
        return labels_n, data_n

    if st.button('Score'):
        try:
            scores = score(samplepd)
            st.write('The predicted outcome is: ')
            st.write(scores)
            samplepd = state.samplepd
            # I_BAD is read as text; strip the padding before converting.
            samplepd['BAD'] = int(scores['I_BAD'][0].strip(" "))
            state.scored = samplepd
        except Exception:
            # Best-effort hint: a failure here is reported as a login problem.
            st.write("Are you sure you're signed into the SAS server?")

    if st.button('Explain my prediction'):
        samplepd = state.scored
        if samplepd is None:
            # Nothing has been scored in this session yet; subscripting
            # None below would otherwise crash the page.
            st.write('Please score the input before asking for an explanation.')
        else:
            st.write('The predicted value for target {} is: {}, and the contributions of the variables are as follows:'.format('BAD', samplepd['BAD'][0]))
            labels, data = explainML(samplepd)
            source = pd.DataFrame({
                'Predictors': labels,
                'Estimated Impact': data
            })
            c = alt.Chart(source).mark_bar().encode(
                x='Predictors',
                y='Estimated Impact',
                color=alt.Color('Predictors'), tooltip=['Predictors', 'Estimated Impact']
            )
            st.altair_chart(c, use_container_width=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment