Last active
March 6, 2020 23:06
-
-
Save epeters3/d29627505c8515116549f2415a8f5d5f to your computer and use it in GitHub Desktop.
This is a naive intent-recognition system for use in dialogue systems; other intents can be added. It is naive because it assumes that the sentence embedding of every natural-language phrasing of an intent lies within two standard deviations of the centroid of that intent's example embeddings. This is a heuristic.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce
from math import sqrt
from statistics import mean

import numpy as np
from scipy.spatial.distance import cosine
from sentence_transformers import SentenceTransformer
# Shared sentence-embedding model. It is created once, when the module is
# loaded, and reused to embed both the intent examples and the user's input.
model = SentenceTransformer("distilbert-base-nli-stsb-mean-tokens")
# A very simple intent engine. Recognize affirmative and
# negative responses using semantic similarity determined
# using sentence embeddings.
#
# Each intent name maps to a dict holding an "examples" list of sample
# phrasings. The script's start-up code later stores two more keys on every
# entry: "embedding" (the centroid of the example embeddings) and
# "embedding_std" (the spread of the examples around that centroid).
intents = {
    "yes": {
        "examples": [
            "absolutely",
            "of course",
            "ok",
            "sure",
            "That's right",
            "Why not",
            "ya",
            "yes",
            "yes i am",
            "yes i think so",
            "Yes.",
            "I would love to.",
        ]
    },
    "no": {
        "examples": [
            "I don't think so",
            "nah.",
            "no",
            "No I don't",
            "no I'm not",
            "No thank you.",
            "No way.",
            "Nope.",
        ]
    },
}
def determine_intent(text: str, k: float, *, encode_fn=None, intent_table=None) -> str:
    """Return the intent whose centroid is nearest to ``text``, or "unknown".

    ``text`` is embedded and compared, by cosine distance, with the
    "embedding" centroid of every intent. The nearest intent is returned
    only if its distance is at most ``k`` times that intent's
    "embedding_std"; otherwise the result is "unknown".

    Args:
        text: The utterance to classify.
        k: Multiplier applied to the nearest intent's "embedding_std" to
            obtain the acceptance threshold.
        encode_fn: Optional callable mapping a string to its embedding.
            Defaults to the module-level ``model.encode``.
        intent_table: Optional mapping of intent name to a dict holding
            "embedding" and "embedding_std". Defaults to the module-level
            ``intents``.

    Returns:
        The name of the nearest intent, or "unknown" when no intent is
        close enough or the table is empty.

    Raises:
        KeyError: If an entry has no "embedding" or "embedding_std" key,
            i.e. the centroids have not been computed yet.
    """
    # Resolve the module-level defaults lazily so that existing
    # two-argument calls keep working unchanged.
    if encode_fn is None:
        encode_fn = model.encode
    if intent_table is None:
        intent_table = intents

    embedding = np.asarray(encode_fn(text))
    # The encoder may return a single 1-D vector or a stack of vectors.
    # Average only a stack: taking the mean over axis 0 of a lone vector
    # would collapse it to a scalar and make the distances meaningless.
    if embedding.ndim > 1:
        embedding = embedding.mean(axis=0)

    # Track the nearest centroid; on ties the first intent seen wins, and a
    # NaN distance is never selected because ``nan < x`` is always False.
    best_intent = None
    best_dist = float("inf")
    best_std = 0.0
    for intent, data in intent_table.items():
        dist_to_centroid = cosine(embedding, data["embedding"])
        if dist_to_centroid < best_dist:
            best_intent = intent
            best_dist = dist_to_centroid
            best_std = data["embedding_std"]

    # Accept the nearest intent only if `text` is close enough to its centroid.
    if best_intent is not None and best_dist <= k * best_std:
        return best_intent
    return "unknown"
def _fit_intent_centroids() -> None:
    """Store a centroid and a spread on every entry of ``intents``.

    For each intent, every example is embedded, the embeddings are averaged
    into a centroid (stored under "embedding"), and the root-mean-square
    cosine distance of the examples to that centroid is stored under
    "embedding_std". The spread of each intent is printed as it is computed.
    """
    for intent, data in intents.items():
        ex_embeddings = []
        for ex in data["examples"]:
            emb = np.asarray(model.encode(ex))
            # Average only when the encoder returned a stack of vectors; the
            # mean over axis 0 of a lone 1-D vector would collapse it to a
            # scalar.
            if emb.ndim > 1:
                emb = emb.mean(axis=0)
            ex_embeddings.append(emb)
        centroid = np.mean(ex_embeddings, axis=0)
        data["embedding"] = centroid
        data["embedding_std"] = sqrt(
            mean(cosine(emb, centroid) ** 2 for emb in ex_embeddings)
        )
        print(f"std of '{intent}':", data["embedding_std"])


def main() -> None:
    """Fit the intent centroids, then classify user input until told to quit."""
    # Create a centroid and standard deviation for each intent
    # based on its examples.
    _fit_intent_centroids()

    # Accept user input and output the system's approximation of their desired intent.
    exit_resp = {"q", "Q", "quit", "exit"}
    while True:
        try:
            user_in = input("Please enter your intent (press 'q' to quit): ")
        except EOFError:
            # End of input (e.g. Ctrl-D or an exhausted pipe) counts as quit
            # rather than crashing with a traceback.
            user_in = "q"
        if user_in in exit_resp:
            print("goodbye.")
            break
        print("your intent is:", determine_intent(user_in, 2.0))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment