Skip to content

Instantly share code, notes, and snippets.

@mkaranasou
Last active October 27, 2021 08:29
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mkaranasou/7aa1f3a28258330679dcab4277c42419 to your computer and use it in GitHub Desktop.
Save mkaranasou/7aa1f3a28258330679dcab4277c42419 to your computer and use it in GitHub Desktop.
How to use scikit-learn's Isolation Forest in PySpark - UDFs and broadcast variables
import numpy as np
from pyspark import SparkConf
from pyspark.sql import SparkSession, functions as F, types as T
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler
# Seed NumPy's global random generator.
np.random.seed(42)

# Get (or create) a Spark session named 'test' from a default SparkConf.
conf = SparkConf()
spark_session = (
    SparkSession.builder
    .config(conf=conf)
    .appName('test')
    .getOrCreate()
)

# Toy dataset: four rows with four numeric features each.
data = [
    {'feature1': 1., 'feature2': 0., 'feature3': 0.3, 'feature4': 0.01},
    {'feature1': 10., 'feature2': 3., 'feature3': 0.9, 'feature4': 0.1},
    {'feature1': 101., 'feature2': 13., 'feature3': 0.9, 'feature4': 0.91},
    {'feature1': 111., 'feature2': 11., 'feature3': 1.2, 'feature4': 1.91},
]
df = spark_session.createDataFrame(data)
df.show()

# Prepare the scaler and the isolation forest, and flatten every record into
# a plain list of values (one inner list per row) for scikit-learn.
scaler = StandardScaler()
classifier = IsolationForest(contamination=0.3, random_state=42, n_jobs=-1)
x_train = [list(row.values()) for row in data]

# Fit the scaler and scale the training rows, then fit the forest on them.
x_train = scaler.fit_transform(x_train)
clf = classifier.fit(x_train)

# Broadcast the fitted scaler and classifier so the UDF can read them.
# Remember: broadcasts work well for relatively small objects.
SCL = spark_session.sparkContext.broadcast(scaler)
CLF = spark_session.sparkContext.broadcast(clf)
def predict_using_broadcasts(feature1, feature2, feature3, feature4) -> int:
    """
    Scale one row of feature values and use the broadcast model to predict.

    The fitted scaler and classifier are read from the module-level
    broadcast variables ``SCL`` and ``CLF``.

    :param feature1: value of the first feature for this row
    :param feature2: value of the second feature for this row
    :param feature3: value of the third feature for this row
    :param feature4: value of the fourth feature for this row
    :return: 1 if normal, -1 if abnormal, 0 if something went wrong
    """
    # Fallback result, returned unchanged when scaling or prediction fails.
    prediction = 0
    # Both transform() and predict() are given a single-row 2D input.
    x_test = [[feature1, feature2, feature3, feature4]]
    try:
        x_test = SCL.value.transform(x_test)
        prediction = CLF.value.predict(x_test)[0]
    except ValueError:
        # Best effort: report the failure and fall through with prediction 0
        # instead of raising out of the UDF.
        import traceback
        traceback.print_exc()
        print('Cannot predict:', x_test)
    # Convert to a built-in int for the integer-typed UDF result
    # (the value taken from predict() is presumably a NumPy scalar).
    return int(prediction)
# Wrap the Python predictor as a Spark UDF producing an integer column.
udf_predict_using_broadcasts = F.udf(
    predict_using_broadcasts,
    returnType=T.IntegerType(),
)

# Add a 'prediction' column computed row by row from the four feature columns.
df = df.withColumn(
    'prediction',
    udf_predict_using_broadcasts(
        F.col('feature1'),
        F.col('feature2'),
        F.col('feature3'),
        F.col('feature4'),
    ),
)
df.show()
+--------+--------+--------+--------+
|feature1|feature2|feature3|feature4|
+--------+--------+--------+--------+
| 1.0| 0.0| 0.3| 0.01|
| 10.0| 3.0| 0.9| 0.1|
| 101.0| 13.0| 0.9| 0.91|
| 111.0| 11.0| 1.2| 1.91|
+--------+--------+--------+--------+
+--------+--------+--------+--------+----------+
|feature1|feature2|feature3|feature4|prediction|
+--------+--------+--------+--------+----------+
| 1.0| 0.0| 0.3| 0.01| -1|
| 10.0| 3.0| 0.9| 0.1| 1|
| 101.0| 13.0| 0.9| 0.91| 1|
| 111.0| 11.0| 1.2| 1.91| 1|
+--------+--------+--------+--------+----------+
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment