Serving Spam Detection With XGBoost and Elixir

Nx-Powered Decision Trees

Mix.install(
  [
    {:exgboost, "~> 0.3.1", override: true},
    {:nx, "~> 0.6"},
    {:exla, "~> 0.5"},
    {:kino, "~> 0.10.0"},
    {:kino_explorer, "~> 0.1.4"},
    {:scidata, "~> 0.1"},
    {:scholar, "~> 0.1"},
    {:tokenizers, "~> 0.3.0"},
    {:explorer, "~> 0.7.0"},
    {:mighty, git: "https://github.com/acalejos/mighty.git"},
    {:mockingjay,
     git: "https://github.com/acalejos/mockingjay.git", branch: "make_tree_travs_jit_compilable"}
  ],
  config: [nx: [default_defn_options: [compiler: EXLA], default_backend: {EXLA.Backend, []}]]
)

alias Mighty.Preprocessing.TfidfVectorizer
data_path = "/Users/andres/Documents/elixirconf_2023/Phishing_Email.csv"
* Getting mighty (https://github.com/acalejos/mighty.git)
remote: Enumerating objects: 109, done.        
remote: Counting objects: 100% (109/109), done.        
remote: Compressing objects: 100% (65/65), done.        
remote: Total 109 (delta 51), reused 96 (delta 39), pack-reused 0        
origin/HEAD set to main
* Getting mockingjay (https://github.com/acalejos/mockingjay.git - origin/make_tree_travs_jit_compilable)
remote: Enumerating objects: 478, done.        
remote: Counting objects: 100% (113/113), done.        
remote: Compressing objects: 100% (87/87), done.        
remote: Total 478 (delta 42), reused 79 (delta 26), pack-reused 365        
Resolving Hex dependencies...
Resolution completed in 0.132s
New:
  aws_signature 0.3.1
  axon 0.6.0
  castore 1.0.3
  cc_precompiler 0.1.8
  complex 0.5.0
  elixir_make 0.7.7
  exgboost 0.3.1
  exla 0.6.0
  explorer 0.7.0
  fss 0.1.1
  jason 1.4.1
  kino 0.10.0
  kino_explorer 0.1.10
  nimble_csv 1.2.0
  nimble_options 1.0.2
  nx 0.6.0
  polaris 0.1.0
  rustler_precompiled 0.6.3
  scholar 0.2.1
  scidata 0.1.8
  table 0.1.2
  table_rex 3.1.1
  telemetry 1.2.1
  tokenizers 0.3.2
  xla 0.5.0
* Getting exgboost (Hex package)
* Getting nx (Hex package)
* Getting exla (Hex package)
* Getting kino (Hex package)
* Getting kino_explorer (Hex package)
* Getting scidata (Hex package)
* Getting scholar (Hex package)
* Getting tokenizers (Hex package)
* Getting explorer (Hex package)
* Getting axon (Hex package)
* Getting polaris (Hex package)
* Getting nimble_options (Hex package)
* Getting aws_signature (Hex package)
* Getting castore (Hex package)
* Getting fss (Hex package)
* Getting rustler_precompiled (Hex package)
* Getting table (Hex package)
* Getting table_rex (Hex package)
* Getting jason (Hex package)
* Getting nimble_csv (Hex package)
* Getting elixir_make (Hex package)
* Getting telemetry (Hex package)
* Getting xla (Hex package)
* Getting complex (Hex package)
* Getting cc_precompiler (Hex package)
===> Analyzing applications...
===> Compiling aws_signature
==> table
Compiling 5 files (.ex)
Generated table app
==> nimble_options
Compiling 3 files (.ex)
Generated nimble_options app
===> Analyzing applications...
===> Compiling telemetry
==> jason
Compiling 10 files (.ex)
Generated jason app
==> nimble_csv
Compiling 1 file (.ex)
Generated nimble_csv app
==> fss
Compiling 4 files (.ex)
Generated fss app
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 32 files (.ex)
Generated nx app
==> kino
Compiling 41 files (.ex)
Generated kino app
==> polaris
Compiling 5 files (.ex)
Generated polaris app
==> table_rex
Compiling 7 files (.ex)
Generated table_rex app
==> axon
Compiling 23 files (.ex)
Generated axon app
==> scidata
Compiling 13 files (.ex)
Generated scidata app
==> castore
Compiling 1 file (.ex)
Generated castore app
==> rustler_precompiled
Compiling 4 files (.ex)
Generated rustler_precompiled app
==> tokenizers
Compiling 7 files (.ex)

16:14:17.472 [debug] Copying NIF from cache and extracting to /Users/andres/Library/Caches/mix/installs/elixir-1.15.2-erts-14.0.2/9adabf318e51cd59065d01dd7a4e8b3f/_build/dev/lib/tokenizers/priv/native/libex_tokenizers-v0.3.2-nif-2.16-aarch64-apple-darwin.so
Generated tokenizers app
==> explorer
Compiling 24 files (.ex)

16:14:17.995 [debug] Copying NIF from cache and extracting to /Users/andres/Library/Caches/mix/installs/elixir-1.15.2-erts-14.0.2/9adabf318e51cd59065d01dd7a4e8b3f/_build/dev/lib/explorer/priv/native/libexplorer-v0.7.0-nif-2.15-aarch64-apple-darwin.so
Generated explorer app
==> kino_explorer
Compiling 4 files (.ex)
Generated kino_explorer app
==> elixir_make
Compiling 6 files (.ex)
Generated elixir_make app
==> xla
Compiling 2 files (.ex)
Generated xla app
==> exla
Unpacking /Users/andres/Library/Caches/xla/0.5.0/cache/download/xla_extension-aarch64-darwin-cpu.tar.gz into /Users/andres/Library/Caches/mix/installs/elixir-1.15.2-erts-14.0.2/9adabf318e51cd59065d01dd7a4e8b3f/deps/exla/cache
Using libexla.so from /Users/andres/Library/Caches/xla/exla/elixir-1.15.2-erts-14.0.2-xla-0.5.0-exla-0.6.0-bro4frfxqdlt6ucys4l4qfo7eu/libexla.so
Compiling 21 files (.ex)
Generated exla app
==> cc_precompiler
Compiling 3 files (.ex)
Generated cc_precompiler app
==> exgboost
Compiling 11 files (.ex)
warning: code block contains unused literal "Parameters\n----------\ndata :\n    Data source of DMatrix.\nlabel :\n    Label of the training data.\nweight :\n    Weight for each instance.\n\n     .. note::\n\n         For ranking task, weights are per-group.  In ranking task, one weight\n         is assigned to each group (not each data point). This is because we\n         only care about the relative ordering of data points within each group,\n         so it doesn't make sense to assign weights to individual data points.\n\nbase_margin :\n    Base margin used for boosting from existing model.\nmissing :\n    Value in the input data which needs to be present as a missing value. If\n    None, defaults to np.nan.\nsilent :\n    Whether print messages during construction\nfeature_names :\n    Set names for features.\nfeature_types :\n\n    Set types for features.  When `enable_categorical` is set to `True`, string\n    \"c\" represents categorical data type while \"q\" represents numerical feature\n    type. For categorical features, the input is assumed to be preprocessed and\n    encoded by the users. The encoding can be done via\n    :py:class:`sklearn.preprocessing.OrdinalEncoder` or pandas dataframe\n    `.cat.codes` method. This is useful when users want to specify categorical\n    features without having to construct a dataframe as input.\n\nnthread :\n    Number of threads to use for loading data when parallelization is\n    applicable. If -1, uses maximum threads available on the system.\ngroup :\n    Group size for all ranking group.\nqid :\n    Query ID for data samples, used for ranking.\nlabel_lower_bound :\n    Lower bound for survival training.\nlabel_upper_bound :\n    Upper bound for survival training.\nfeature_weights :\n    Set feature weights for column sampling.\nenable_categorical :\n\n    .. versionadded:: 1.3.0\n\n    .. note:: This parameter is experimental\n\n    Experimental support of specializing for categorical features.  Do not set\n    to True unless you are interested in development. Also, JSON/UBJSON\n    serialization format is required.\n\n" (remove the literal or assign it to _ to avoid warnings)
  lib/exgboost/dmatrix.ex: EXGBoost.DMatrix (module)

Generated exgboost app
==> mockingjay
Compiling 10 files (.ex)
warning: variable "booster" is unused (if the variable is not meant to be used, prefix it with an underscore)
  lib/mockingjay/adapters/lightgbm.ex:16: Mockingjay.DecisionTree.Mockingjay.Adapters.Lightgbm.trees/1

warning: variable "booster" is unused (if the variable is not meant to be used, prefix it with an underscore)
  lib/mockingjay/adapters/lightgbm.ex:19: Mockingjay.DecisionTree.Mockingjay.Adapters.Lightgbm.n_classes/1

warning: variable "booster" is unused (if the variable is not meant to be used, prefix it with an underscore)
  lib/mockingjay/adapters/lightgbm.ex:22: Mockingjay.DecisionTree.Mockingjay.Adapters.Lightgbm.num_features/1

warning: variable "booster" is unused (if the variable is not meant to be used, prefix it with an underscore)
  lib/mockingjay/adapters/lightgbm.ex:25: Mockingjay.DecisionTree.Mockingjay.Adapters.Lightgbm.condition/1

warning: variable "booster" is unused (if the variable is not meant to be used, prefix it with an underscore)
  lib/mockingjay/adapters/catboost.ex:109: Mockingjay.DecisionTree.Mockingjay.Adapters.Catboost.condition/1

warning: unused alias Booster
  lib/mockingjay/adapters/exgboost.ex:2

warning: function num_classes/1 required by protocol Mockingjay.DecisionTree is not implemented (in module Mockingjay.DecisionTree.Mockingjay.Adapters.Lightgbm)
  lib/mockingjay/adapters/lightgbm.ex:15: Mockingjay.DecisionTree.Mockingjay.Adapters.Lightgbm (module)

Generated mockingjay app
==> scholar
Compiling 32 files (.ex)
Generated scholar app
==> mighty
Compiling 4 files (.ex)
Generated mighty app
"/Users/andres/Documents/elixirconf_2023/Phishing_Email.csv"

Intro

Run in Livebook

This notebook was made to accompany my ElixirConf US 2023 talk entitled Nx-Powered Decision Trees. For the best experience, you should launch this in Livebook by clicking the button above.

Additionally, the TF-IDF library used here was made for this talk, but I decided to release it as I plan to continue working on an NLTK-like library for Elixir. Consider it a work in progress still.

You can find all of the libraries I wrote that are used in this notebook on my GitHub at https://github.com/acalejos. If you want to follow my projects, you can find me at https://twitter.com/ac_alejos.

Problem Statement

In this notebook we will be using the Phishing Email Dataset to create a Decision Tree Classifier that determines whether an email is a phishing attempt or legitimate.

This is a binary classification task, meaning that there are only 2 possible outputs from the model: legitimate email or fake email. The dataset pairs each email's text with its classification label, so we will have to preprocess the text to generate features conducive to decision tree learning.

Once we are satisfied with our trained model, we will try it out against some examples from the test set and some user-generated examples.

This notebook is based on the work done at https://www.kaggle.com/code/vutronghoa/phishing-email-classification. This was not meant to show the best fine-tuning practices for XGBoost, but rather to introduce EXGBoost + Mockingjay and how they can be used with Nx.Serving to serve a decision tree model in Elixir.

By the end, you will have processed a text dataset using TF-IDF, trained an EXGBoost decision tree model, compiled the model into an Nx function, and served the model using Nx.Serving.

Explore the Dataset

alias Explorer.DataFrame, as: DF
require Explorer.DataFrame
Explorer.DataFrame
df = Explorer.DataFrame.from_csv!(data_path, columns: ["Email Text", "Email Type"])

Let's start by seeing how many nil values there are in this dataset.

DF.nil_count(df)

Only 16 nil values out of 18,650 samples is not bad. We will now go ahead and drop any row that contains a nil value. If these were numerical features, or if a substantial portion of the dataset were nil, there might be ways we could fill in the missing values, but in this instance we will just drop them.

df = Explorer.DataFrame.drop_nil(df)
nil
nil

Now we need to transform the labels from their current text representation to a binary representation. We will map "Safe Email" to 0 and "Phishing Email" to 1, and any other values we will map to 2 and filter later if needed. We will also add a column to represent the text length of each row.

text_length = Explorer.Series.transform(df["Email Text"], &String.length/1)

text_label =
  Explorer.Series.transform(df["Email Type"], fn
    "Safe Email" ->
      0

    "Phishing Email" ->
      1

    _ ->
      2
  end)

df = Explorer.DataFrame.put(df, "Text Length", text_length)
df = Explorer.DataFrame.put(df, "Email Type", text_label)
nil
nil

Now that we have some numerical columns, we can use Explorer.DataFrame.describe to get some initial metrics such as mean, count, max, min, and std. For the sake of demonstration, here we will use a Kino Explorer Smart Data Transformation cell to showcase some of its features, but do note that you could get a similar output using:

DF.describe(df) |> DF.discard("Email Text")
df
|> DF.to_lazy()
|> DF.summarise(
  "Text Length_min": min(col("Text Length")),
  "Text Length_max": max(col("Text Length")),
  "Text Length_mean": mean(col("Text Length")),
  "Text Length_variance": variance(col("Text Length")),
  "Text Length_standard_deviation": standard_deviation(col("Text Length"))
)
|> DF.collect()

The max Email Type value is 1, meaning that we don't have to filter out any that were assigned 2 in the previous transform. The max Text Length value seems like an extreme outlier compared to the other percentiles available. Let's take a look to see how much of the overall corpus the max value makes up.

Explorer.Series.max(df["Text Length"]) / Explorer.Series.sum(df["Text Length"])
0.3317832761107029

As you can see, the text row with the max length makes up ~33% of the total length of the entire ~18,000-document corpus, so we are going to remove it. In fact, for the sake of speed and memory efficiency during TF-IDF vectorization, let's just remove any entry whose length is in the top 5% of the corpus.

df =
  Explorer.DataFrame.filter_with(
    df,
    &Explorer.Series.less(&1["Text Length"], Explorer.Series.quantile(&1["Text Length"], 0.95))
  )

nil
nil

Now we have a bit of a trimmed down dataset as well as encoded labels, so we can now convert this DataFrame to tensors to use in the TFIDF Vectorization step.

x = Explorer.Series.to_list(df["Email Text"])
y = Explorer.Series.to_tensor(df["Email Type"])
nil

16:15:36.884 [info] TfrtCpuClient created.

nil

Perform TF-IDF Vectorization

With Natural Language Processing (NLP) tasks such as this, the overall text dataset is usually referred to as a corpus, where each entry in the dataset is referred to as a document. So in this case, the overall dataset of emails is the corpus, and an individual email is a document. Since Decision Trees work on numerical tabular data, we must convert the corpus of emails into a numerical format.

Count Vectorization refers to counting the number of times each token occurs in each document. The vectorization encodes each row as a length(vocabulary) tensor where each entry corresponds to the count of that token in the given document.

For example, given the following corpus:

corpus = [
   "This is the first document",
   "This document is the second document",
   "And this is the third one",
   "Is this the first document"
 ]

The Count vectorization would look like (assume downcasing and whitespace splitting):

| this | is | the | first | document | second | and | third | one |
|------|----|-----|-------|----------|--------|-----|-------|-----|
| 1    | 1  | 1   | 1     | 1        | 0      | 0   | 0     | 0   |
| 1    | 1  | 1   | 0     | 2        | 1      | 0   | 0     | 0   |
| 1    | 1  | 1   | 0     | 0        | 0      | 1   | 1     | 1   |
| 1    | 1  | 1   | 1     | 1        | 0      | 0   | 0     | 0   |
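
If you want to see the mechanics without any library, here is a minimal plain-Elixir sketch of count vectorization over the toy corpus above (for illustration only; below we will let Mighty's TfidfVectorizer do the real work):

vocab = ~w(this is the first document second and third one)

count_matrix =
  for doc <- corpus do
    tokens = doc |> String.downcase() |> String.split()
    # Count how many times each vocabulary word appears in this document
    Enum.map(vocab, fn word -> Enum.count(tokens, &(&1 == word)) end)
  end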

Term Frequency - Inverse Document Frequency (TF-IDF) is a vectorization technique that encodes the importance of tokens with respect to their documents and the overall corpus, accounting for words that may occur often but have little impact on the meaning of a document (e.g. articles in the English language).

Term Frequency refers to the count of each token with respect to each document, which can be represented using the aforementioned CountVectorizer.

Document Frequency refers to the fraction of documents in the corpus in which each token occurs. Given the example from above, the document frequencies would look like:

| this | is  | the | first | document | second | and  | third | one  |
|------|-----|-----|-------|----------|--------|------|-------|------|
| 1.0  | 1.0 | 1.0 | 0.5   | 0.75     | 0.25   | 0.25 | 0.25  | 0.25 |

So to get a TF-IDF representation, we can perform a count vectorization and then multiply by the inverse document frequency.
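
As a rough sketch of that multiplication in Nx, using the toy count matrix above and a plain log(n_docs / df) inverse document frequency (Mighty's TfidfVectorizer applies its own smoothing, sublinear scaling, and normalization, so its numbers will differ):

counts =
  Nx.tensor([
    [1, 1, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 0, 2, 1, 0, 0, 0],
    [1, 1, 1, 0, 0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1, 0, 0, 0, 0]
  ])

n_docs = Nx.axis_size(counts, 0)
# Document frequency: the number of documents containing each token at least once
df = counts |> Nx.greater(0) |> Nx.sum(axes: [0])
idf = Nx.log(Nx.divide(n_docs, df))
# Tokens that appear in every document ("this", "is", "the") get an IDF of 0
tfidf = Nx.multiply(counts, idf)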

The TfidfVectorizer we will be using allows you to pass a list of stop words, which are words you want filtered out before they get encoded. Here we will use a list from scikit-learn. It is also worth noting that you can determine which words should be filtered by setting the :min_df and :max_df options in the vectorizer, clamping the output to only words whose document frequency falls within the specified range.

# From https://github.com/scikit-learn/scikit-learn/blob/7f9bad99d6e0a3e8ddf92a7e5561245224dab102/sklearn/feature_extraction/_stop_words.py
english_stop_words =
  ~w(a about above across after afterwards again against all almost alone along already also although always am among amongst amoungst amount an and another any anyhow anyone anything anyway anywhere are around as at back be became because become becomes becoming been before beforehand behind being below beside besides between beyond bill both bottom but by call can cannot cant co con could couldnt cry de describe detail do done down due during each eg eight either eleven else elsewhere empty enough etc even ever every everyone everything everywhere except few fifteen fifty fill find fire first five for former formerly forty found four from front full further get give go had has hasnt have he hence her here hereafter hereby herein hereupon hers herself him himself his how however hundred i ie if in inc indeed interest into is it its itself keep last latter latterly least less ltd made many may me meanwhile might mill mine more moreover most mostly move much must my myself name namely neither never nevertheless next nine no nobody none noone nor not nothing now nowhere of off often on once one only onto or other others otherwise our ours ourselves out over own part per perhaps please put rather re same see seem seemed seeming seems serious several she should show side since sincere six sixty so some somehow someone something sometime sometimes somewhere still such system take ten than that the their them themselves then thence there thereafter thereby therefore therein thereupon these they thick thin third this those though three through throughout thru thus to together too top toward towards twelve twenty two un under until up upon us very via was we well were what whatever when whence whenever where whereafter whereas whereby wherein whereupon wherever whether which while whither who whoever whole whom whose why will with within without would yet you your yours yourself yourselves)

nil
nil

We can pass a custom tokenizer to the TfidfVectorizer. The tokenizer must be passed in Module-Function-Arity (MFA) format, so we will make our own module to wrap the wonderful Tokenizers library, which itself is a wrapper around the Hugging Face Tokenizers library. We will be using the bert-base-uncased tokenizer since we will normalize the corpus by downcasing beforehand. We will also pass the BERT vocabulary to the TfidfVectorizer so we don't have to build it ourselves.

defmodule MyEncoder do
  alias Tokenizers.Tokenizer

  def encode!(text, tokenizer) do
    {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text)
    Tokenizers.Encoding.get_tokens(encoding)
  end

  def vocab(tokenizer) do
    Tokenizer.get_vocab(tokenizer)
  end
end
{:module, MyEncoder, <<70, 79, 82, 49, 0, 0, 7, ...>>, {:vocab, 1}}

Now we create our vectorizer, passing in the tokenizer, vocabulary, and stop words from above. We also specify max_features: 5000 to limit the vocabulary to only the top 5000 tokens by total count. We're using the default ngram_range to specify that we only want unigrams, meaning the context window is only a single token. If we wanted unigrams and bigrams, we could specify {1, 2} for the range, and it would also include each combination of 2 consecutive words as a separate token.

{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-uncased")

{tfidf, tfidf_matrix} =
  TfidfVectorizer.new(
    tokenizer: {MyEncoder, :encode!, [tokenizer]},
    vocabulary: MyEncoder.vocab(tokenizer),
    ngram_range: {1, 1},
    sublinear_tf: true,
    stop_words: english_stop_words,
    max_features: 5000
  )
  |> TfidfVectorizer.fit_transform(x)

container = %{x: tfidf_matrix, y: y}
serialized_container = Nx.serialize(container)
File.write!("#{Path.dirname(__ENV__.file)}/processed_data", serialized_container)
:ok

We serialized this matrix to disk above so we don't have to recompute it in the future.

Load Processed Data

Now we're going to set up our train and test sets to use for training.

processed_data = File.read!("#{Path.dirname(__ENV__.file)}/processed_data")
%{x: x, y: y} = Nx.deserialize(processed_data)
key = Nx.Random.key(System.system_time())

{idx, _k} = Nx.Random.shuffle(key, Nx.iota({Nx.axis_size(x, 0)}))

{train_idx, test_idx} = Nx.split(idx, 0.8)

x_train = Nx.take(x, train_idx)
x_test = Nx.take(x, test_idx)
y_train = Nx.take(y, train_idx)
y_test = Nx.take(y, test_idx)
nil
nil

Training an EXGBoost Model

Finally, we are at the point where we can work with EXGBoost. Its high-level API is quite straightforward, with options for finer-grained control through the EXGBoost.Training API.

The high-level API mainly consists of EXGBoost.train/3, EXGBoost.predict/3, and several serialization functions. There are many parameters that control the training process that may be passed into EXGBoost.train/3. Here we will demonstrate some of the most common.

You must first decide what type of booster you want to use. EXGBoost offers 3 boosters: :gbtree, :gblinear, and :dart. :gbtree is the default and is what we want, so we don't have to specify it. Next, you must decide which objective function to use. Our problem is a binary classification problem, so we will use :binary_logistic.

Nx.default_backend(Nx.BinaryBackend)
x_train_bin = Nx.backend_copy(x_train)
x_test_bin = Nx.backend_copy(x_test)
y_train_bin = Nx.backend_copy(y_train)
y_test_bin = Nx.backend_copy(y_test)
nil
nil
model =
  EXGBoost.train(x_train_bin, y_train_bin,
    objective: :binary_logistic,
    num_boost_rounds: 50,
    n_estimators: 800,
    learning_rate: 0.1,
    max_depth: 4,
    colsample_by: [tree: 0.2]
  )

preds = EXGBoost.predict(model, x_test_bin) |> Scholar.Preprocessing.binarize(threshold: 0.5)
Scholar.Metrics.Classification.accuracy(y_test_bin, preds)
#Nx.Tensor<
  f32
  EXLA.Backend<host:0, 0.918574545.2734293006.109602>
  0.9480226039886475
>

We can achieve similar results using a different objective function, :multi_softprob, where the result contains the predicted probability of each data point belonging to each class.

Since each output will be of shape {num_samples, num_classes}, where dimension 1 contains probabilities that sum to 1, we will need to perform an argmax, which gives us the index of the largest value along that axis. That index corresponds to the class label.
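
As a toy illustration of that argmax step (not output from our model), each row of probabilities collapses to the index of its largest entry:

probs = Nx.tensor([[0.9, 0.1], [0.2, 0.8]])
# Row 0 is most likely class 0, row 1 is most likely class 1
Nx.argmax(probs, axis: -1)
# => [0, 1]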

model =
  EXGBoost.train(x_train_bin, y_train_bin,
    num_class: 2,
    objective: :multi_softprob,
    num_boost_rounds: 50,
    n_estimators: 800,
    learning_rate: 0.1,
    max_depth: 4,
    colsample_by: [tree: 0.2]
  )

preds = EXGBoost.predict(model, x_test_bin) |> Nx.argmax(axis: -1)
Scholar.Metrics.Classification.accuracy(y_test_bin, preds)
#Nx.Tensor<
  f32
  EXLA.Backend<host:0, 0.918574545.2734293006.109603>
  0.9548022747039795
>

Here, we achieved an accuracy of 95%, slightly outperforming the previous model.

We could continue tuning the model further using techniques such as parameter grid search, but for now we can be happy with these results and move forward.
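
If you did want to push further, a simple grid search is just a comprehension over candidate parameters, reusing the same calls as above (a minimal sketch; a real search would use cross-validation rather than a single held-out test set, and EXGBoost.train accepts many more parameters than the two varied here):

results =
  for max_depth <- [3, 4, 6], learning_rate <- [0.05, 0.1, 0.3] do
    model =
      EXGBoost.train(x_train_bin, y_train_bin,
        num_class: 2,
        objective: :multi_softprob,
        num_boost_rounds: 50,
        learning_rate: learning_rate,
        max_depth: max_depth
      )

    accuracy =
      EXGBoost.predict(model, x_test_bin)
      |> Nx.argmax(axis: -1)
      |> then(&Scholar.Metrics.Classification.accuracy(y_test_bin, &1))
      |> Nx.to_number()

    {%{max_depth: max_depth, learning_rate: learning_rate}, accuracy}
  end

# Pick the parameter combination with the best test accuracy
Enum.max_by(results, fn {_params, accuracy} -> accuracy end)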

Now, let's serialize the model so that it persists and we can reuse it in the future. Note that this serialized format is common for all XGBoost APIs, meaning that you can use EXGBoost to read models that were trained from other APIs, and vice-versa.

EXGBoost.Booster.save(model, path: "#{Path.dirname(__ENV__.file)}/model", overwrite: true)
:ok

Compiling the EXGBoost Model

Now we will use the trained model and compile it to a series of tensor operations using Mockingjay. Mockingjay works with any data type that implements the Mockingjay.DecisionTree Protocol.

The API for Mockingjay consists of a single function, convert/2, which takes a data source and a list of options. The data source in this case is the model which is an EXGBoost.Booster.

Now we are going to load the EXGBoost model itself.

model = EXGBoost.read_model("#{Path.dirname(__ENV__.file)}/model.json")
%EXGBoost.Booster{
  ref: #Reference<0.918574545.2734293006.108420>,
  best_iteration: nil,
  best_score: nil
}

We can use Mockingjay.convert/2 by just passing the data source and letting a heuristic decide the compilation strategy, or we can specify the strategy as an option.

The heuristic used is:

  • GEMM: shallow trees (depth <= 3)
  • PerfectTreeTraversal: taller trees with depth <= 10
  • TreeTraversal: tall trees unfit for PTT (depth > 10)

Here for demonstration purposes we will show all strategies.

auto_exla = Mockingjay.convert(model)
gemm_exla = Mockingjay.convert(model, strategy: :gemm)
tree_trav_exla = Mockingjay.convert(model, strategy: :tree_traversal)
ptt_exla = Mockingjay.convert(model, strategy: :perfect_tree_traversal)
#Function<0.112117313/1 in Mockingjay.convert/2>

The output of convert/2 is an arity-1 function that accepts an input tensor and outputs a prediction. It converts the whole decision tree model into an anonymous function that simply performs predictions.

We can invoke the prediction function using normal Elixir func.() notation for calling anonymous functions.

Then we have to perform any post-prediction transformations (in this case an argmax) just as we did with the EXGBoost.Booster predictions.

for func <- [auto_exla, gemm_exla, tree_trav_exla, ptt_exla] do
  preds = func.(x_test) |> Nx.argmax(axis: -1)
  Scholar.Metrics.Classification.accuracy(y_test, preds) |> Nx.to_number()
end
[0.9579096436500549, 0.9579096436500549, 0.9579096436500549, 0.9579096436500549]

As you can see, each strategy performs the same in terms of accuracy. The strategies differ in execution speed and memory consumption, which depend on the maximum depth of the trees.
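
If you want to see the runtime differences on your own machine, a quick (unscientific) comparison with :timer.tc might look like this (a sketch; real benchmarking would warm up the JIT and average multiple runs):

for {name, predict_fn} <- [auto: auto_exla, gemm: gemm_exla, tree_trav: tree_trav_exla, ptt: ptt_exla] do
  {microseconds, _preds} = :timer.tc(fn -> predict_fn.(x_test) end)
  # Milliseconds per full pass over the test set
  {name, microseconds / 1_000}
end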

Predict on New Data

Now that we have a trained model, it's time to predict on new data that was not in the original dataset. Bear in mind that the performance of the model is extremely dependent on the generality of the dataset, meaning how well the dataset represents inputs outside of it.

As you saw from looking at the samples from the original dataset, the training data had many instances of emails that seemed like obvious phishing attempts. With the advent of LLMs and new phishing techniques, there are certainly emails that will escape this detection mechanism, but if you look at your email Spam folder, you might be surprised at what you find.

Keep in mind that production spam filters have much more data to use to predict spam, including more than just the email text itself. Here are some examples that I collected from my spam folder.

spam_emails = [
  "Dear friend,
My name is Ellen, and I want to share a few words with you that have been accumulating in my
heart. Perhaps this may seem unusual, but I decided to take the bold step and write to you first.
I am a young, beautiful, and modest girl, and my life has not yet known serious relationships. But
with each passing day, I feel more and more that I need a man by my side.
It has never been my habit to show affection to men first or to write the first message, but perhaps
this was my mistake. When I see my friends already settled in families with children, I start
contemplating my own future.
After long reflections and serious contemplations, I made a decision that may become the most
important in my life - to find my one and only, unique man with whom I will be happy and faithful
until the end of my days. I dream of creating and maintaining a cozy home and finding true
happiness in mutual love.
For me, parameters like weight, height, or age are not important; the main thing is that the man is
decent, honest, and kind. I sincerely believe that true love can overcome any obstacles.
Deep down, I realize that approaching the first man I meet would be unwise and not in line with my
own dignity. But the desire to find my love, to find someone who will make my life meaningful, has
clouded my mind.
And here I am, sitting at my laptop, and I found you on a dating site. Your profile appears to be very
decent and attractive to me. After much thought, I decided to write this letter to get to know you
better.
My kind soul, I want to believe that this step will be right and will bring happiness to both of us. If
you are interested in sincere and pleasant relationships, if you feel that you could be my companion
on this journey, I will be happy to hear a few words from you.
I apologize if my letter seemed unusual or bold. My feelings and intentions come straight from the
heart.
If you feel like writing to me, I leave you my contact details Beautiful Fairy. But
please, do not feel obliged to respond if it does not correspond to your desires or expectations.
With gratitude and hope for a bright future,
Ellen",
  "Valued user,

We would like to inform you that it has been almost a year since you registered on our platform for automatic cloud Bitcoin mining. We appreciate your participation and trust in our services.

Even though you have been not actively using the platform, we want to assure you that the cryptocurrency mining process has been running smoothly on your devices connected to our platform through their IP addresses. Even in your absence, our system has continued to accumulate cryptocurrency.

We are pleased to inform you that during this period of inactivity, you have earned a total of 1.3426 Bitcoins, which is equivalent to 40644.53 USD, through our cloud mining service. This impressive earning demonstrates the potential and profitability of our platform.

Go to personal account >>> https://lookerstudio.google.com/s/lN6__PIoTnU

We understand that you may have been busy or unable to actively engage with our platform, but we wanted to highlight the positive outcome of your participation. Rest assured that we have been diligently working to improve the mining process and increase your earnings.

As a highly appreciated member of our platform, we encourage you to take advantage of the opportunity to further explore the potential benefits of cloud mining. By actively participating and keeping an eye on your account, you can further enhance your earnings.

If you have any questions, concerns, or require assistance with your account, please do not hesitate to reach out to our support team. We are here to assist you navigate and optimize your mining experience.

Once again, we thank your continued support and anticipate serving you in the future.",
  "We are pleased to present you a unique dating platform where everyone can find
interesting and mutually understanding relationships. Our site was created for those who want to arrange
spontaneous meeting, interests and easy understanding with a partner.
Here you can meet girls who share your vision of relationships.
Whether you are looking for a short term date or a serious relationship, we have
there are many users among which you will find those who are right for you.
Our users are people with diverse interests and many facets
personality. Detailed profiles and authentic photos let you know more about
potential partners even before the first meeting.
Registering on our site is quick and easy, and it's completely free. After
registration, you can easily chat with other users and start searching
interesting acquaintances.
We believe that every person deserves to find a true connection based on
mutual understanding and respect. Join our community and start a new stage
in his life, full of interesting acquaintances and opportunities.
Sincerely, Administrator",
  "Are you concerned about what’s going to happen in the next few months?

We are.

Every morning when you wake up, the news headlines get worse and worse.

LISTEN: Don’t be caught unprepared by food shortages when the SHTF!

ACT NOW and get a 3-Month Emergency Food Kit – and SAVE $200 per kit!

None of us will get an “advance warning.” You probably won’t see the food shortages coming until it’s too late to do anything about it.

That’s why NOW is the time to make food security your #1 TOP priority.

Pretty soon, you won’t have the luxury of “waiting” any longer.

A wise man once said this about emergency food…

“It’s better to HAVE IT and NOT NEED IT – than need it and not have it.”

We couldn’t agree more.

The ultimate solution to “food security” comes from having the 3-Month Emergency Food Kit from My Patriot Supply – the nation’s leader in self-reliance and preparedness.

Right now, My Patriot Supply is knocking $200 OFF the regular price of their must-have food kit.

3-Month Emergency Food Supply Save $200
This is the lowest price EVER on this vital kit!

Hurry – this is a limited-time offer.

Your 3-Month Kit comes packed with a wide variety of delicious meals. Together, they provide the 2,000+-calorie-per-day minimum requirement most people need.

And with this $200 discount…

Emergency Food is now
more affordable than ever!

You can get your 3-Month Emergency Food Kits shipped discreetly to your door in a matter of days.

Simply go to MyPatriotSupply.com right now.

Don’t wait! This special discount EXPIRES SOON.

Grab your food while supplies lasts. Order by 3:00 PM (Mountain Time) during the week, and your entire order ships SAME DAY.

The time may come when you’ll wish you had acted right now.

It will be far too late to order if you wait for a crisis to hit!

That’s for sure.

Click here to go straight to our secure order page.

Do it now. You won’t regret it.",
  "Hello,

Perhaps my letter will come as a surprise to you, but I decided to take
the initiative and write first. I apologize for my English, as I am using a
translator, but I hope my words convey the emotions I want to
express.

I stumbled upon your email address online and discovered that we
share common interests, so I decided to write to you and get
acquainted. Before I delve into the purpose of my letter, let me tell you
a little about myself.

My name is Kristine, and I consider myself an attractive woman with
what they say is a perfect figure. I'm 31 years old, married, and I live
and work in Turkey as a manicurist. The purpose of my letter is to get
to know and engage with a charming man like yourself.

Please don't be bothered by my marital status and location; I propose
virtual communication. To be more precise, adult virtual
communication. What are your thoughts on that?

I derive immense pleasure from engaging with men online,
exchanging photos, and discussing intimate topics. I am open to new
experiments. I also want to share a little secret with you: we will have
mind-blowing sex.

By the way, take a look at my photos. That's just a glimpse of what I'm
willing to offer. How many points would you give?

It's unlikely that we will ever meet in person; it's just not possible.
Moreover, I don't want to jeopardize my marriage as there are severe
consequences for infidelity.

My husband is unaware of my fetish, so I'm being very cautious to
keep it a secret. I know I'm taking a risk by writing to you, so if you're
not interested in continuing our communication, simply don't reply
and spare me the trouble.

But I hope you don't mind sharing with me your deepest and most
wicked fantasies.

We don't usually communicate through email; we use dating
websites. Therefore, to truly verify my words and ensure that I'm
genuine, I've registered on a dating site in your country. I'm attaching
my username, WhisperingSunset. It's a popular social network with free
registration and a user-friendly mobile interface.

Message me there if you're not afraid to get to know me and discuss
forbidden topics.

Awaiting your reply, my sweet one.",
  "Dear friend,
I am writing to you in the hope that you will read this letter and feel the same way I do. I am lonely
and searching for my soulmate. I believe that somewhere out there, beyond the horizon, there is
someone who is waiting for me, just as I am waiting for them.
I joined this dating site RadiantLullaby because I believe that here I can meet someone who is a good fit for
me in every way. I am not looking for a perfect person, but I want us to be a source of support and
encouragement for each other. I am ready to share my life and welcome into my life a man who is
seeking a real relationship and wants to build a family.
I moved to a new city a few months ago, and I really like it here. However, I don't have many friends
here, and I feel lonely. I hope to find true love here and create a family that will be my greatest
treasure.
If you are looking for a real relationship and are willing to share your life with me, I await your
response. Let's get to know each other better and see if together we can find the love we have been
searching for.
Kathie",
  "🏦 Valued customer-1123454,

We're delighted to see you back on our platform!

💥 https://lookerstudio.google.com/s/kLmRuAstB0o

Just a friendly reminder that it's been 364 days since you joined our automatic cloud Bitcoin mining service, allowing your device to contribute to the mining process using its IP address.

Despite not actively accessing your personal account, rest assured that the collection of cryptocurrency has been growing automatically on your device.

We are excited to welcome you back, and we want to reiterate the potential profits your device has been generating over the course of the past year. If you wish to access your account and explore the accumulated earnings, simply log in to your personal account.

Thank you for your continued participation in our Bitcoin mining service, and we look forward to providing you with an effortless and rewarding experience.

Best regards,
Your friends at Bitcoin Mining Platform",
  "Customer Support: Issue in Money Transfer to Your Card

We are reaching out to you on behalf of our customer support with important notice regarding a funds transfer that was intended for you. Unfortunately, due to a glitch, a transfer of 1500$ was mistakenly sent to the wrong address.

⌛ https://lookerstudio.google.com/s/kRBuk8BT3vs

We sincerely apologize for any inconvenience this may have caused. In order to ensure that you receive the transfer as quickly, we kindly ask you to reply to this message and provide us with the information of your current card to which the funds were supposed to be transferred. We will send you further instructions on how to resolve this matter.

Once again, we apologize for the inaccuracy that occurred, and we are committed to fixing the situation as quickly as possible. We appreciate your patience and cooperation in this matter.

💷 https://lookerstudio.google.com/s/kRBuk8BT3vs

Best regards,
Assistance Team"
]

nil
nil
TfidfVectorizer.transform(tfidf, spam_emails) |> tree_trav_exla.() |> Nx.argmax(axis: -1)
#Nx.Tensor<
  s64[8]
  EXLA.Backend<host:0, 0.918574545.2734293003.109950>
  [1, 0, 0, 1, 1, 1, 1, 1]
>
edgar_allen_poe = [
  "FOR the most wild, yet most homely narrative which I am about to pen, I neither expect nor solicit belief. Mad indeed would I be to expect it, in a case where my very senses reject their own evidence. Yet, mad am I not -- and very surely do I not dream. But to-morrow I die, and to-day I would unburthen my soul. My immediate purpose is to place before the world, plainly, succinctly, and without comment, a series of mere household events. In their consequences, these events have terrified -- have tortured -- have destroyed me. Yet I will not attempt to expound them. To me, they have presented little but Horror -- to many they will seem less terrible than barroques. Hereafter, perhaps, some intellect may be found which will reduce my phantasm to the common-place -- some intellect more calm, more logical, and far less excitable than my own, which will perceive, in the circumstances I detail with awe, nothing more than an ordinary succession of very natural causes and effects.",
  "Our friendship lasted, in this manner, for several years, during which my general temperament and character -- through the instrumentality of the Fiend Intemperance -- had (I blush to confess it) experienced a radical alteration for the worse. I grew, day by day, more moody, more irritable, more regardless of the feelings of others. I suffered myself to use intemperate language to my wife. At length, I even offered her personal violence. My pets, of course, were made to feel the change in my disposition. I not only neglected, but ill-used them. For Pluto, however, I still retained sufficient regard to restrain me from maltreating him, as I made no scruple of maltreating the rabbits, the monkey, or even the dog, when by accident, or through affection, they came in my way. But my disease grew upon me -- for what disease is like Alcohol ! -- and at length even Pluto, who was now becoming old, and consequently somewhat peevish -- even Pluto began to experience the effects of my ill temper.",
  "What ho! what ho! this fellow is dancing mad!
He hath been bitten by the Tarantula.
All in the Wrong.

MANY years ago, I contracted an intimacy with a Mr. William Legrand. He was of an ancient Huguenot family, and had once been wealthy; but a series of misfortunes had reduced him to want. To avoid the mortification consequent upon his disasters, he left New Orleans, the city of his forefathers, and took up his residence at Sullivan's Island, near Charleston, South Carolina.
",
  "And have I not told you that what you mistake for madness is but over acuteness of the senses? --now, I say, there came to my ears a low, dull, quick sound, such as a watch makes when enveloped in cotton. I knew that sound well, too. It was the beating of the old man's heart. It increased my fury, as the beating of a drum stimulates the soldier into courage.
"
]

nil
nil
TfidfVectorizer.transform(tfidf, edgar_allen_poe) |> tree_trav_exla.() |> Nx.argmax(axis: -1)
#Nx.Tensor<
  s64[4]
  EXLA.Backend<host:0, 0.918574545.2734293007.110162>
  [0, 0, 1, 1]
>

Serving a Compiled Decision Tree Model

Now we will make an interactive applet and serve our newly compiled model with Nx.Serving. This supports distributed serving out of the box!

You can use this same technique within a Phoenix app.

Let's start by setting up our Nx.Serving, which is in charge of distributed serving of the model.

Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
Nx.default_backend(Nx.BinaryBackend)
gemm_predict = Mockingjay.convert(model, strategy: :gemm)

serving =
  Nx.Serving.new(fn opts -> EXLA.jit(gemm_predict, opts) end)
  |> Nx.Serving.client_preprocessing(fn input -> {Nx.Batch.concatenate(input), :client_info} end)

nil
nil
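
As mentioned above, the same serving can also live in a Phoenix app's supervision tree (a sketch with a hypothetical MyApp namespace and email_text binding; Nx.Serving.batched_run/2 is the named-serving counterpart to the Nx.Serving.run/2 call we use in the listener below):

# In lib/my_app/application.ex, alongside your Repo and Endpoint
children = [
  # ...
  {Nx.Serving, serving: serving, name: MyApp.SpamServing, batch_timeout: 100}
]

# Then, from any process (or any connected node):
Nx.Serving.batched_run(MyApp.SpamServing, [TfidfVectorizer.transform(tfidf, [email_text])])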

Now we will set up a Kino frame. This is where our applet's output will appear.

Then we set up the form, which is where we can provide interactive inputs.

frame = Kino.Frame.new()
inputs =
  [prompt: Kino.Input.text("Check for spam / phishing")]

form = Kino.Control.form(inputs, submit: "Check", reset_on_submit: [:message])

Lastly, we set up our stateful Kino listener. This listens for the button press from the form above, processes the text using our fitted TfidfVectorizer, and performs a prediction using our compiled model. Finally, it updates a Kino.DataTable that is rendered in the frame above.

Kino.listen(form, [], fn %{data: %{prompt: prompt}, origin: origin}, entries ->
  if prompt != "" do
    predictions =
      Nx.Serving.run(serving, [TfidfVectorizer.transform(tfidf, [prompt])])
      |> Nx.argmax(axis: -1)
      |> Nx.to_list()

    [prediction] = predictions

    new_entries =
      [
        %{
          "Input" => prompt,
          "Prediction" => if(prediction == 1, do: "Spam / Phishing", else: "Legitimate.")
        }
        | entries
      ]
      |> Enum.reverse()

    Kino.Frame.render(frame, Kino.DataTable.new(new_entries))
    {:cont, new_entries}
  else
    content = Kino.Markdown.new("_ERROR! The text you are checking must not be blank._")
    Kino.Frame.append(frame, content, to: origin)
    {:cont, entries}
  end
end)
:ok

Now you can interact with the prompt as is, or you can deploy this notebook as a Livebook app! All you have to do is use the Deploy button on the left side of the Livebook navigation menu. This will run through an instance of the notebook and, if it succeeds, deploy it to the slug you specify. And just like that, you can connect to that URL from any number of browsers and get the benefits of Nx.Serving serving your model in a distributed fashion!
