Skip to content

Instantly share code, notes, and snippets.

@kawine
Last active October 4, 2019 07:53
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kawine/647747cef4f53ce1896f2a46b1402a61 to your computer and use it in GitHub Desktop.
Save kawine/647747cef4f53ce1896f2a46b1402a61 to your computer and use it in GitHub Desktop.
import numpy as np
from scipy.stats import pearsonr, ttest_ind
from scipy.spatial.distance import cosine
# Words to keep when loading counts; fill this in with your own vocabulary.
ANALOGY_VOCAB = set()
class pair2joint(object):
    """Load co-occurrence counts and calculate PMI and csPMI."""

    def __init__(self, fn='counts.txt'):
        """
        Read word-pair counts from `fn`.

        counts.txt should be of the format `word1 word2 count(word1,word2)`
        per line. Blank lines are skipped. A pair is kept only when both of
        its words are in ANALOGY_VOCAB.
        """
        # (word1, word2) -> count; every kept pair is stored in both orders.
        self.joint = {}
        # word -> sum of the counts of every kept pair that contains the word.
        self.marginal = {}
        # `with` guarantees the file is closed even if a line fails to parse.
        with open(fn) as counts:
            for line in counts:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                a, b, freq = fields
                freq = int(freq)
                if a in ANALOGY_VOCAB and b in ANALOGY_VOCAB:
                    self.joint[(a, b)] = freq
                    self.joint[(b, a)] = freq
                    self.marginal[a] = self.marginal.get(a, 0) + freq
                    self.marginal[b] = self.marginal.get(b, 0) + freq
        # total number of word pairs = sum(marginal.values()) / 2
        # since there are two entries in self.marginal for each word pair
        self.total = sum(self.marginal.values()) / 2.0

    def PMI(self, a, b):
        """
        Return log(joint(a,b) * total / (marginal(a) * marginal(b))).

        Raises KeyError if the pair or either word was not loaded.
        """
        return (np.log(self.joint[(a, b)]) + np.log(self.total)
                - np.log(self.marginal[a]) - np.log(self.marginal[b]))

    def csPMI(self, a, b):
        """
        Return log(joint(a,b)**2 / (marginal(a) * marginal(b))).

        Raises KeyError if the pair or either word was not loaded.
        """
        return (2 * np.log(self.joint[(a, b)])
                - np.log(self.marginal[a]) - np.log(self.marginal[b]))

    def __getitem__(self, x):
        """
        Return the count of `x` divided by the total number of word pairs.

        `x` is either a (word1, word2) tuple, looked up in the joint counts,
        or a single word, looked up in the marginal counts.
        Raises KeyError if `x` was not loaded.
        """
        # Dispatch on the type, not on len(x): a two-character word such as
        # "is" also has length 2 and must be treated as a word, not a pair.
        if isinstance(x, tuple):
            return self.joint[x] / self.total
        return self.marginal[x] / self.total
# Module-level count table; constructing it reads the default 'counts.txt' when this module is imported.
FREQUENCIES = pair2joint()
def calc_stats(fn='questions-words.txt', frequencies=None):
    """
    Calculate stats for analogy categories given in the format of Mikolov et al. (questions-words.txt).
    Word pairs from capital-world, for example, would include ('Paris', 'France'), ('Berlin', 'Germany'), etc.
    Statistics are calculated for word pairs from that category.

    Each distinct word pair contributes exactly once per category, however
    many analogy lines it appears in. Pairs missing from the counts are
    skipped. For every category one line is printed:
    name, mean csPMI, mean PMI, median joint count, variance of csPMI.

    fn: path of the analogy file; lines starting with ':' open a category,
        every other non-blank line holds four words `a b c d`.
    frequencies: object providing csPMI(a, b), PMI(a, b) and a `joint`
        mapping keyed by (a, b); defaults to the module-level FREQUENCIES.

    Raises ValueError if an analogy line appears before any category header
    or does not contain exactly four words.
    """
    if frequencies is None:
        frequencies = FREQUENCIES

    # category -> {(word1, word2): (csPMI, PMI, joint count)}
    # Keying by pair de-duplicates all three statistics consistently.
    stats = {}
    pairs = None
    with open(fn) as questions:
        for line in questions:
            line = line.strip()
            if not line:
                continue
            if line.startswith(':'):
                pairs = {}
                stats[line[1:].strip()] = pairs
                continue
            if pairs is None:
                raise ValueError(
                    'analogy line found before any category header: %r' % line)
            a, b, c, d = line.split()
            for pair in ((a, b), (c, d)):
                if pair in pairs:
                    continue
                try:
                    pairs[pair] = (frequencies.csPMI(*pair),
                                   frequencies.PMI(*pair),
                                   frequencies.joint[pair])
                except KeyError:
                    # The pair (or one of its words) has no counts; leave it out.
                    pass

    for category, pairs in stats.items():
        rows = list(pairs.values())
        cspmi = [row[0] for row in rows]
        pmi = [row[1] for row in rows]
        counts = [row[2] for row in rows]
        print(category, np.mean(cspmi), np.mean(pmi), np.median(counts), np.var(cspmi))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment