Skip to content

Instantly share code, notes, and snippets.

Last active January 5, 2022 10:23
Show Gist options
  • Save mkaranasou/126a6363d3ab423ecd0728afe588cdeb to your computer and use it in GitHub Desktop.
Save mkaranasou/126a6363d3ab423ecd0728afe588cdeb to your computer and use it in GitHub Desktop.
Calculate the Shapley marginal contribution for each feature of a given dataset-model pair
import os
from psutil import virtual_memory
from pyspark import SparkConf
from import Vectors, VectorUDT
from pyspark.sql import functions as F, SparkSession, types as T, Window
def get_spark_session():
    """Build (or reuse) a SparkSession sized to this machine.

    With an effort to optimize memory and partitions:
    - ``spark.driver.memory`` is half of the host's total physical memory,
      rounded to whole GiB and never below ``1G``;
    - ``spark.sql.shuffle.partitions`` is twice the number of CPUs.

    Returns:
        The session returned by ``SparkSession.builder.getOrCreate()``.
        Per Spark's documentation, that call reuses an already-running
        session if one exists.
    """
    mem = virtual_memory()
    conf = SparkConf()
    # Half of total RAM, expressed in whole GiB as Spark memory strings
    # expect (e.g. "8G"). Clamp to at least 1: on a very small host the
    # rounding would otherwise produce "0G", an unusable driver-memory value.
    memory_gib = max(1, int(round((mem.total / 2) / 1024 / 1024 / 1024, 0)))
    memory = f'{memory_gib}G'
    conf.set('spark.driver.memory', memory)
    # os.cpu_count() returns None when the count cannot be determined;
    # fall back to 1 so the multiplication below cannot raise TypeError.
    cpu_count = os.cpu_count() or 1
    conf.set('spark.sql.shuffle.partitions', str(cpu_count * 2))
    return SparkSession \
        .builder \
        .config(conf=conf) \
        .appName("IForest feature importance") \
        .getOrCreate()
# Return `df` with an extra array column of feature-name literals.
# According to its docstring, that column is named `output_col` and holds a
# shuffled ordering of the features. The docstring's example shows a
# different ordering on each row.
# NOTE(review): this copy is truncated. The following are not visible:
#   - the parameter list;
#   - the docstring quotes;
#   - the column-name argument to `withColumn`;
#   - any shuffle call (e.g. `F.shuffle`).
# Only `F.array(...)` over the names in `feature_names` order is visible;
# confirm against the original gist.
def get_features_permutations(
Creates a column for the ordered features and then shuffles it.
The result is a dataframe with a column `output_col` that contains:
[feat2, feat4, feat3, feat1],
[feat3, feat4, feat2, feat1],
[feat1, feat2, feat4, feat3],
# one string literal per feature name, in `feature_names` order
return df.withColumn(
F.array(*[F.lit(f) for f in feature_names])
# Estimate one Shapley value per feature for `row_of_interest` by sampling
# over the rows of `df`.
# - Each row supplies a sample z (its features) and a permutation of the
#   feature names (the `features_permutations` column).
# - The UDF builds the pair x-j / x+j from them, and `model.transform` scores
#   both.
# - Each feature's result is the average of the resulting 'marginal_contribution'
#   differences (see `marginal_contribution_filter`).
# The function returns `ordered_results`, built with `sorted(...)` over the
# collected per-feature values. The sort key is not visible in this copy.
# NOTE(review): this copy of the function is truncated. The parameter list,
# the docstring quotes and several argument lines are missing. The body
# references df, model, row_of_interest, feature_names, features_col and
# column_to_examine. Confirm the signature against the original gist.
def calculate_shapley_values(
# Based on the algorithm described here:
# And on Baskerville's implementation for IForest/ AnomalyModel here:
# NOTE(review): the links these two comments introduce are missing from this
# copy. The first is presumably the Shapley-value estimation section of
# C. Molnar's "Interpretable Machine Learning"; confirm.
# feature name -> averaged marginal contribution, filled in by the loop below
results = {}
features_perm_col = 'features_permutations'
spark = get_spark_session()
# aggregate that averages the 'marginal_contribution' column; its alias
# argument and closing parenthesis are missing from this copy
marginal_contribution_filter = F.avg('marginal_contribution').alias(
# broadcast the row of interest and ordered feature names
# (both are read inside the `calculate_x` UDF through `.value`)
ROW_OF_INTEREST_BROADCAST = spark.sparkContext.broadcast(row_of_interest)
ORDERED_FEATURE_NAMES = spark.sparkContext.broadcast(feature_names)
# persist before continuing with calculations
# (only when the caller has not already cached `df`)
if not df.is_cached:
df = df.persist()
# get permutations
# NOTE(review): the call's arguments are missing from this copy. It is
# presumably given `df`, `feature_names` and `features_perm_col`, since the UDF
# below reads the `features_perm_col` column; confirm.
features_df = get_features_permutations(
# set up the udf - x-j and x+j need to be calculated for every row
def calculate_x(
        feature_j, z_features, curr_feature_perm
):
    """Build the pair (x-j, x+j) for one sampled row z and one permutation.

    The instance x+j is the instance of interest,
    but all values in the order before feature j are
    replaced by feature values from the sample z.
    The instance x-j is the same as x+j, but in addition
    has feature j replaced by the value for feature j from the sample z.

    Taking the features in `curr_feature_perm` order:
        x+j = [z1, ..., zj-1, xj, xj+1, ..., xN]
        x-j = [z1, ..., zj-1, zj, xj+1, ..., xN]
    Both vectors keep the column order of the broadcast feature names.

    Args:
        feature_j: name of the feature whose contribution is measured.
        z_features: feature values of the sampled row z, in the same order
            as the broadcast feature names.
        curr_feature_perm: this row's permutation of the feature names.

    Returns:
        (x-j, x+j) as dense vectors. x-j comes first because the caller
        explodes the pair and takes a lag of 1 to get x+j minus x-j.

    Raises:
        ValueError: if `feature_j` (or any name after it in the permutation)
            is not among the broadcast feature names.
    """
    x_interest = ROW_OF_INTEREST_BROADCAST.value.features
    ordered_features = ORDERED_FEATURE_NAMES.value
    # list() already produces a fresh copy of z for each vector
    x_minus_j = list(z_features)
    x_plus_j = list(z_features)
    j_perm_pos = curr_feature_perm.index(feature_j)
    # feature j itself: x+j takes the value of interest, x-j keeps z's value
    j_idx = ordered_features.index(feature_j)
    x_plus_j[j_idx] = x_interest[j_idx]
    # Every feature after j in the permutation takes the value of interest
    # in both vectors. The original tested `i > f_i` on an index into the
    # slice starting at j, which wrongly left the first `f_i` of these
    # features at their z value in x-j whenever j was not first.
    for name in curr_feature_perm[j_perm_pos + 1:]:
        idx = ordered_features.index(name)
        x_plus_j[idx] = x_interest[idx]
        x_minus_j[idx] = x_interest[idx]
    # minus must be first because of lag
    return Vectors.dense(x_minus_j), Vectors.dense(x_plus_j)
# Spark UDF wrapper around calculate_x. For each row it returns an array of
# two vectors, [x-j, x+j], matching calculate_x's return order.
udf_calculate_x = F.udf(calculate_x, T.ArrayType(VectorUDT()))
# persist before processing
features_df = features_df.persist()
# One pass per feature: build x-j / x+j for every row, score both with the
# model, and average the per-row differences into results[f].
for f in feature_names:
# x column contains x-j and x+j in this order.
# Because lag is calculated this way:
# F.col('anomalyScore') - (F.col('anomalyScore') one row before)
# x-j needs to be first in `x` column array so we should have:
# id1, [x-j row i, x+j row i]
# ...
# that with explode becomes:
# id1, x-j row i
# id1, x+j row i
# ...
# to give us (x+j - x-j) when we calculate marginal contribution
# Note that with explode, x-j and x+j for the same row have the same id
# This gives us the opportunity to use lag with
# a window partitioned by id
# NOTE(review): the closing parentheses of this call (and anything chained
# after it) are missing from this copy.
x_df = features_df.withColumn('x', udf_calculate_x(
F.lit(f), features_col, features_perm_col
print(f'Calculating SHAP values for "{f}"...')
# Keep only `id`, plus one row per vector. The exploded vector is stored
# under `features_col`, the column name the UDF read z's features from.
x_df = x_df.selectExpr(
'id', f'explode(x) as {features_col}'
# score both x-j and x+j rows; `column_to_examine` is read from the output
x_df = model.transform(x_df)
# marginal contribution is calculated using a window and a lag of 1.
# the window is partitioned by id because x+j and x-j for the same row
# will have the same id
# NOTE(review): the new column's name argument and the `.over(...)` window
# are missing from this copy. The lag is only meaningful if, within each
# `id` partition, the x-j row is ordered before the x+j row. Both rows share
# the same `id`, so ordering by `id` alone would leave their relative order
# unspecified. Confirm the window's ordering guarantees x-j comes first,
# e.g. by ordering on an explode position.
x_df = x_df.withColumn(
F.col(column_to_examine) - F.lag(
F.col(column_to_examine), 1
# calculate the average
# NOTE(review): the filter condition is missing from this copy. It presumably
# drops the rows whose lag is null (the first row of each `id` partition),
# leaving one x+j minus x-j difference per sampled row; confirm.
x_df = x_df.filter(
# NOTE(review): the right-hand side is missing from this copy. It presumably
# collects the `marginal_contribution_filter` average; confirm.
results[f] =
# drop this feature's DataFrame reference before the next iteration
# NOTE(review): the visible code never unpersists `features_df`, nor `df`
# when it was persisted above.
del x_df
print(f'Marginal Contribution for feature: {f} = {results[f]} ')
# NOTE(review): the arguments to sorted() are missing from this copy. It
# presumably orders the (feature, value) items of `results`; confirm the
# key and direction.
ordered_results = sorted(
return ordered_results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment