
@GarrettMooney
Last active October 25, 2022 23:45
%pip install "whatlies[sentence_tfm]"  # quotes for my fellow zsh users

import numpy as np
from whatlies.language import SentenceTFMLanguage
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression

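# Step "embed" turns each sentence into a vector with a pretrained
# sentence-transformers model; step "model" classifies those vectors.
pipe = Pipeline([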
    ("embed", SentenceTFMLanguage('distilbert-base-nli-stsb-mean-tokens')),
    ("model", LogisticRegression())
])

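# Toy sentiment data: the first three sentences are positive (label 1),
# the last three are negative (label 0).
X = [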
    "i really like this post",
    "thanks for that comment",
    "i enjoy this friendly forum",
    "this is a bad post",
    "i dislike this article",
    "this is not well written"
]

y = np.array([1, 1, 1, 0, 0, 0])

pipe.fit(X, y)
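The whatlies language object can sit at the "embed" step because it behaves like a scikit-learn transformer (it has `fit`/`transform`), which is what the pipeline above relies on. Below is a minimal sketch of that contract that runs offline. It assumes scikit-learn's `BaseEstimator`/`TransformerMixin`. `ToyHashEmbedder` is a hypothetical stand-in that hashes words into a bag-of-words vector. It is not a sentence transformer and not part of whatlies.

```python
import zlib

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline


class ToyHashEmbedder(BaseEstimator, TransformerMixin):
    """Hypothetical offline stand-in for a sentence embedder.

    Maps each text to a fixed-size bag-of-words vector by hashing every
    word into one of `n_dims` buckets. It is NOT a sentence transformer.
    """

    def __init__(self, n_dims=32):
        self.n_dims = n_dims

    def fit(self, X, y=None):
        # Nothing to learn: like a pretrained embedder, the mapping is fixed.
        return self

    def transform(self, X):
        out = np.zeros((len(X), self.n_dims))
        for i, text in enumerate(X):
            for word in text.lower().split():
                # crc32 is deterministic across runs, unlike Python's hash().
                out[i, zlib.crc32(word.encode("utf-8")) % self.n_dims] += 1.0
        return out


X_toy = [
    "i really like this post",
    "thanks for that comment",
    "i enjoy this friendly forum",
    "this is a bad post",
    "i dislike this article",
    "this is not well written",
]
y_toy = np.array([1, 1, 1, 0, 0, 0])

# Same two-step layout as the pipeline above, with the stand-in embedder.
toy_pipe = Pipeline([
    ("embed", ToyHashEmbedder()),
    ("model", LogisticRegression()),
])
toy_pipe.fit(X_toy, y_toy)
preds = toy_pipe.predict(X_toy)
```

Any object with `fit` returning `self` and a `transform` that yields a 2-D array can be swapped in at the "embed" step. The classifier never needs to know where the vectors came from.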