Skip to content

Instantly share code, notes, and snippets.

@lcschv
Last active March 2, 2023 10:06
Show Gist options
  • Save lcschv/75ee379e687c67460955edea28798857 to your computer and use it in GitHub Desktop.
Save lcschv/75ee379e687c67460955edea28798857 to your computer and use it in GitHub Desktop.
from nltk.probability import FreqDist
import math
import pickle
from top2vec import Top2Vec
import numpy as np
from gensim.utils import simple_preprocess
from gensim.parsing.preprocessing import strip_tags
def default_tokenizer(doc):
    """Tokenize a document and drop words that are too long or too short.

    The document first has its markup tags stripped, then it is handed to
    gensim's ``simple_preprocess`` with de-accenting enabled (``deacc=True``).
    """
    # Borrowed from Top2Vec's own default tokenizer. If the model was trained
    # with a custom tokenizer, use that same tokenizer for the measure instead
    # of this one.
    without_tags = strip_tags(doc)
    return simple_preprocess(without_tags, deacc=True)
def PWI(model, docs, num_topics=20, num_words=20):
    """
    Compute a topic-weighted PWI (probability-weighted amount of information)
    score for a Top2Vec model over a collection of documents.

    For every document ``d`` the model is queried for its ``num_topics``
    closest (reduced) topics. For each of the first ``num_words`` words ``w``
    of each topic that also occurs in ``d`` the following term is accumulated::

        P(w|d) * topic_score * log( P(d, w) / (P(w) * P(d)) )

    with ``P(d) = 1 / len(docs)``, ``P(w|d)`` the relative frequency of ``w``
    inside ``d``, ``P(w)`` the relative frequency of ``w`` in the whole
    collection and ``P(d, w) = P(w|d) * P(d)``.

    Note: this function calls ``model.hierarchical_topic_reduction`` and
    therefore modifies the state of ``model``.

    :param model: top2vec model
    :param docs: list of strings; statistics are computed from these documents
    :param num_topics: number of topics to use in the computation
    :param num_words: number of words to use
    :return: PWI value (float); 0.0 when ``docs`` is empty
    """
    # Nothing to measure, and 1 / len(docs) below would divide by zero.
    # len() is used (rather than truthiness) so array-like inputs work too.
    if len(docs) == 0:
        return 0.0

    model.hierarchical_topic_reduction(num_topics)

    # Tokenize once and reuse the result for both the corpus-level and the
    # per-document statistics.
    # Remember to change the tokenizer if a different one was used to train the model.
    tokenized_docs = [default_tokenizer(doc) for doc in docs]

    # Corpus-wide word counts. A generator is used instead of np.concatenate,
    # which is fragile with ragged token lists (documents with no tokens).
    word_frequencies = FreqDist(
        token for tokens in tokenized_docs for token in tokens
    )
    # Hoisted so the corpus total is not recomputed for every word lookup.
    total_tokens = word_frequencies.N()

    # Word counts of each document, aligned by position with ``docs``.
    doc_frequencies = [FreqDist(tokens) for tokens in tokenized_docs]

    p_d = 1 / len(docs)
    pwi_total = 0.0

    # Iterate through the whole dataset and query the topics of each document.
    for doc, doc_freq in zip(docs, doc_frequencies):
        topic_words, _word_scores, topic_scores, _topic_nums = model.query_topics(
            query=doc, num_topics=num_topics, reduced=True
        )
        # topic_scores holds the importance of each topic for this document.
        for words, t_score in zip(topic_words, topic_scores):
            for word in words[:num_words]:
                if word not in doc_freq:
                    # Topic word absent from this document (also covers the
                    # case of scoring on a collection other than the training one).
                    continue
                # Relative frequency of the word inside the document: P(w|d).
                p_w_given_d = doc_freq.freq(word)
                # word is in this document, hence in the corpus: total_tokens > 0.
                p_w = word_frequencies[word] / total_tokens
                # P(d, w) = P(w|d) * P(d)
                p_d_and_w = p_w_given_d * p_d
                weight = p_w_given_d * t_score
                pwi_total += weight * math.log(p_d_and_w / (p_w * p_d))
    return pwi_total
if __name__ == '__main__':
    # Example run: score a Top2Vec model on the 20 Newsgroups corpus.
    from sklearn.datasets import fetch_20newsgroups

    # Download the full corpus without headers, footers and quoted replies.
    stripped_parts = ('headers', 'footers', 'quotes')
    newsgroups = fetch_20newsgroups(subset='all', remove=stripped_parts)

    # Train the model the way the original Top2Vec repository shows.
    model = Top2Vec(documents=newsgroups.data, speed="learn", workers=8)

    # To avoid retraining, the model can be stored and restored with pickle:
    #   pickle.dump(model, open('top2vec-20news.pkl', 'wb'))
    #   model = pickle.load(open('top2vec-20news.pkl', 'rb'))

    score = PWI(model=model, docs=newsgroups.data, num_topics=20, num_words=20)
    print("PWI:", score)
@Vela-zz
Copy link

Vela-zz commented Mar 2, 2023

`p_d_given_w = dict_docs_freqs[i].freq(word)` — I wonder whether this line actually computes P(w | d) rather than P(d | w).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment