Last active
October 27, 2021 08:29
-
-
Save mkaranasou/7aa1f3a28258330679dcab4277c42419 to your computer and use it in GitHub Desktop.
How to use scikit-learn's Isolation Forest in PySpark with UDFs and broadcast variables
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from pyspark import SparkConf
from pyspark.sql import SparkSession, functions as F, types as T
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler
# Seed NumPy's global random number generator so the run is repeatable.
np.random.seed(42)

# Get (or create) a Spark session called 'test' from a default SparkConf.
conf = SparkConf()
spark_session = (
    SparkSession.builder
    .config(conf=conf)
    .appName('test')
    .getOrCreate()
)

# Toy input: four rows, each with the numeric columns feature1..feature4.
data = [
    {'feature1': 1.0, 'feature2': 0.0, 'feature3': 0.3, 'feature4': 0.01},
    {'feature1': 10.0, 'feature2': 3.0, 'feature3': 0.9, 'feature4': 0.1},
    {'feature1': 101.0, 'feature2': 13.0, 'feature3': 0.9, 'feature4': 0.91},
    {'feature1': 111.0, 'feature2': 11.0, 'feature3': 1.2, 'feature4': 1.91},
]
df = spark_session.createDataFrame(data)
df.show()
# instantiate a scaler, an isolation forest classifier and convert the data into the appropriate form | |
scaler = StandardScaler() | |
classifier = IsolationForest(contamination=0.3, random_state=42, n_jobs=-1) | |
x_train = [list(n.values()) for n in data] | |
# fit on the data | |
x_train = scaler.fit_transform(x_train) | |
clf = classifier.fit(x_train) | |
# broadcast the scaler and the classifier objects | |
# remember: broadcasts work well for relatively small objects | |
SCL = spark_session.sparkContext.broadcast(scaler) | |
CLF = spark_session.sparkContext.broadcast(clf) | |
def predict_using_broadcasts(feature1, feature2, feature3, feature4):
    """
    Label a single row using the broadcast scaler and Isolation Forest.

    The four feature values are scaled with ``SCL.value`` and then passed to
    ``CLF.value.predict``. Both are broadcast variables defined at module level.

    :return: 1 if normal, -1 if abnormal, 0 if scaling or prediction raised a
        ValueError (the traceback and the offending input are printed).
    """
    features = [[feature1, feature2, feature3, feature4]]
    try:
        # Rebind `features` so a predict() failure reports the scaled values.
        features = SCL.value.transform(features)
        label = CLF.value.predict(features)[0]
    except ValueError:
        import traceback
        traceback.print_exc()
        print('Cannot predict:', features)
        label = 0
    # The classifier yields a NumPy integer; Spark's IntegerType wants a
    # plain Python int.
    return int(label)
# Wrap the row-level predictor as a Spark UDF that returns an IntegerType
# label, then attach its result to every row as a new 'prediction' column.
udf_predict_using_broadcasts = F.udf(predict_using_broadcasts, T.IntegerType())
df = df.withColumn(
    'prediction',
    udf_predict_using_broadcasts(
        *(F.col(name) for name in ('feature1', 'feature2', 'feature3', 'feature4'))
    ),
)
df.show()
Output of the first df.show() (input data):
+--------+--------+--------+--------+
|feature1|feature2|feature3|feature4|
+--------+--------+--------+--------+
|     1.0|     0.0|     0.3|    0.01|
|    10.0|     3.0|     0.9|     0.1|
|   101.0|    13.0|     0.9|    0.91|
|   111.0|    11.0|     1.2|    1.91|
+--------+--------+--------+--------+
Output of the second df.show() (with predictions; -1 marks an outlier):
+--------+--------+--------+--------+----------+
|feature1|feature2|feature3|feature4|prediction|
+--------+--------+--------+--------+----------+
|     1.0|     0.0|     0.3|    0.01|        -1|
|    10.0|     3.0|     0.9|     0.1|         1|
|   101.0|    13.0|     0.9|    0.91|         1|
|   111.0|    11.0|     1.2|    1.91|         1|
+--------+--------+--------+--------+----------+
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment