Created
December 25, 2023 07:36
-
-
Save Katsumata420/8bcf27f566b616204c8ba035e00ae227 to your computer and use it in GitHub Desktop.
Jaccard score for NLP (including a weighted variant)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from typing import List | |
from collections import defaultdict | |
import fugashi | |
import ipadic | |
class JaccardNLP:
    """Set-based Jaccard similarity between two corpora of documents.

    Documents are tokenized with a MeCab tagger (via ``fugashi``) that uses
    the ``ipadic`` dictionary; each corpus is reduced to the set of distinct
    tokens it contains.
    """

    def __init__(self):
        """Build the tagger from the dictionary directory shipped with ``ipadic``."""
        dic_dir = ipadic.DICDIR
        mecabrc = os.path.join(dic_dir, "mecabrc")
        # Paths are quoted so that directories containing spaces survive
        # MeCab's option parsing; -Owakati selects the "wakati" output format.
        mecab_option = f'-d "{dic_dir}" -r "{mecabrc}" -Owakati'
        self.tokenizer = fugashi.GenericTagger(mecab_option)

    def tokenize(self, text: str) -> List[str]:
        """Return the tokens of ``text``.

        The tagger's parse output is split on whitespace.
        """
        return self.tokenizer.parse(text).split()

    def _vocabulary(self, corpus: List[str]) -> set:
        """Return the set of distinct tokens found in all documents of ``corpus``."""
        vocabulary = set()
        for doc in corpus:
            vocabulary.update(self.tokenize(doc))
        return vocabulary

    def similarity(self, corpus_a: List[str], corpus_b: List[str]) -> float:
        """Compute the Jaccard similarity between ``corpus_a`` and ``corpus_b``.

        Args:
            corpus_a (List[str]): A list of documents.
            corpus_b (List[str]): A list of documents.

        Returns:
            float: ``|A & B| / |A | B|`` over the token sets of both corpora,
            or ``0.0`` when neither corpus yields any token.
        """
        word_a = self._vocabulary(corpus_a)
        word_b = self._vocabulary(corpus_b)
        union = word_a | word_b
        if not union:
            # Both corpora are empty (or tokenize to nothing): avoid dividing
            # by zero and report "no similarity".
            return 0.0
        return len(word_a & word_b) / len(union)
class WeightedJaccardNLP(JaccardNLP):
    """Weighted Jaccard similarity.

    Reference: https://arxiv.org/pdf/2305.10703.pdf
    """

    def __init__(self):
        """Initialise the tokenizer exactly as the base class does."""
        super().__init__()

    def _count_tokens(self, corpus: List[str]) -> defaultdict:
        """Return a token -> occurrence-count mapping over every document in ``corpus``."""
        frequencies = defaultdict(int)
        for text in corpus:
            for token in self.tokenize(text):
                frequencies[token] += 1
        return frequencies

    def similarity(self, corpus_a: List[str], corpus_b: List[str]) -> float:
        """Compute the weighted Jaccard similarity of two corpora.

        Only tokens that occur in both ``corpus_a`` and ``corpus_b`` are
        considered. For those tokens, the difference between their occurrence
        counts in the two corpora is taken into account: the result is the sum
        of the smaller counts divided by the sum of the larger counts.

        Returns:
            float: The weighted score, or ``0.0`` when the corpora share no
            token.
        """
        freq_a = self._count_tokens(corpus_a)
        freq_b = self._count_tokens(corpus_b)

        shared = freq_a.keys() & freq_b.keys()
        if not shared:
            return 0.0

        overlap = sum(min(freq_a[token], freq_b[token]) for token in shared)
        total = sum(max(freq_a[token], freq_b[token]) for token in shared)
        return overlap / total
def main(jaccard_type: str = "weighted") -> None:
    """Print the similarity of two small sample corpora.

    Args:
        jaccard_type: Which scorer to use: ``"normal"`` for ``JaccardNLP`` or
            ``"weighted"`` for ``WeightedJaccardNLP``. Defaults to
            ``"weighted"``, the variant the script previously hard-coded.

    Raises:
        ValueError: If ``jaccard_type`` is neither ``"normal"`` nor
            ``"weighted"``.
    """
    # Resolve the scorer class first so an invalid name fails before any
    # tagger is constructed.
    if jaccard_type == "normal":
        scorer_cls = JaccardNLP
    elif jaccard_type == "weighted":
        scorer_cls = WeightedJaccardNLP
    else:
        raise ValueError("Invalid jaccard_type")

    corpus_a = ["今日はクリスマスです", "今日は晴れです", "今日は雨です", "今日は雨です"]
    corpus_b = ["今日は雨です", "今日は雨です", "今日は雨です"]

    jaccard = scorer_cls()
    score = jaccard.similarity(corpus_a, corpus_b)
    print(score)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment