Created
May 21, 2024 07:28
-
-
Save do-me/386fc8ff512c7ff070d77cdb29bee53e to your computer and use it in GitHub Desktop.
Cosine similarity with nan checks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy.linalg import norm | |
import numpy as np | |
# Define the cosine similarity function with automatic list-to-array conversion
def cos_sim(a, b):
    """Return the cosine similarity of two vectors, or NaN for unusable input.

    Parameters
    ----------
    a, b : list, tuple or numpy.ndarray
        The vectors to compare. Lists and tuples are converted to NumPy
        arrays before the computation.

    Returns
    -------
    float
        ``dot(a, b) / (norm(a) * norm(b))``, or ``np.nan`` when either input
        is ``None``, empty (or otherwise falsy), contains ``None``, an empty
        string or NaN, or when either vector has zero norm.
    """
    if a is None or b is None:
        return np.nan

    for vec in (a, b):
        if isinstance(vec, np.ndarray):
            # The truth value of an ndarray is ambiguous (it raises for more
            # than one element), so emptiness is tested through its size.
            if vec.size == 0:
                return np.nan
        elif not vec:
            return np.nan

    # Placeholder elements (None / "") mark a missing value in the sequence.
    for vec in (a, b):
        if isinstance(vec, (list, tuple)) and any(
            x is None or x == "" for x in vec
        ):
            return np.nan

    # asarray leaves existing ndarrays untouched and converts sequences.
    a = np.asarray(a)
    b = np.asarray(b)

    if np.isnan(a).any() or np.isnan(b).any():
        return np.nan

    denominator = norm(a) * norm(b)
    # A zero-length vector has no direction; avoid the 0/0 division warning.
    if denominator == 0:
        return np.nan

    return np.dot(a, b) / denominator
# Example usage: each pair is passed to cos_sim and the result is printed.
for _left, _right in (
    ([1, 2, 3], [4, 5, 6]),       # Should return a valid cosine similarity value
    ([np.nan, 2, 3], [4, 5, 6]),  # Should return np.nan
    ([], [4, 5, 6]),              # Should return np.nan
    (None, [4, 5, 6]),            # Should return np.nan
    (["", 2, 3], [4, 5, 6]),      # Should return np.nan
):
    print(cos_sim(_left, _right))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment