Skip to content

Instantly share code, notes, and snippets.

@xuwangyin
Created June 9, 2016 06:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save xuwangyin/23b49ef28c3a8df11c9a668592a5ff18 to your computer and use it in GitHub Desktop.
Save xuwangyin/23b49ef28c3a8df11c9a668592a5ff18 to your computer and use it in GitHub Desktop.
import scipy
import numpy as np
from collections import Counter
# Kullback–Leibler divergence
# https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
# http://scipy.github.io/devdocs/generated/scipy.stats.entropy.html
def kl(p, q):
    """Return a KL-divergence-based similarity score between two samples.

    Despite its name, this does not return the raw Kullback-Leibler
    divergence. The score is computed as::

        (1 - D_KL(P || Q)) * len(common)**2 / (len(set(p)) * len(set(q)))

    where ``common`` is the set of elements occurring in both samples and
    ``P`` / ``Q`` are the relative frequencies of those common elements
    within ``p`` and ``q`` respectively.

    Args:
        p: Iterable of hashable, mutually orderable observations
            (repeated elements count as repeated observations).
        q: Second iterable of observations, same requirements as ``p``.

    Returns:
        The integer ``0`` when the samples share no element; otherwise a
        Python ``float``. The divergence is unbounded, so the score is not
        bounded below and may be negative.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that
    # the ``scipy.stats`` submodule has been loaded.
    from scipy.stats import entropy

    set_p = set(p)
    set_q = set(q)
    intersection = set_p & set_q

    # Similarity is 0 when there are no common elements.
    if not intersection:
        return 0

    # Count occurrences of the common elements in each sample.
    count_p = Counter(e for e in p if e in intersection)
    count_q = Counter(e for e in q if e in intersection)

    # Both probability vectors are indexed by the same explicit element
    # order, so position i refers to the same element in each of them.
    common = sorted(intersection)
    sum_p = float(sum(count_p.values()))
    sum_q = float(sum(count_q.values()))
    prob_p = [count_p[e] / sum_p for e in common]
    prob_q = [count_q[e] / sum_q for e in common]

    # Similarity of the common elements: 1 - D_KL(P || Q).
    intersection_similarity = 1. - entropy(prob_p, prob_q)

    # Share of common elements relative to each sample's distinct elements.
    area_ratio = float(len(intersection) ** 2) / (len(set_p) * len(set_q))

    # Convert the NumPy scalar to a built-in float so callers (and anything
    # that prints the value) see a plain Python number.
    return float(intersection_similarity * area_ratio)
# Load the input table. Every non-blank line carries three whitespace-separated
# integers; from their use below, column 0 is a repetition count, column 1 the
# grouping key and column 2 the value being counted.
with open('user_content_1_count.txt') as f:
    # Blank lines (e.g. a trailing empty line) are skipped instead of
    # breaking the three-column parse.
    lines = [line for line in f if line.strip()]

d = np.empty((len(lines), 3), np.int32)
for i, line in enumerate(lines):
    # Build a real list: on Python 3 ``map`` returns a lazy iterator, which
    # cannot be assigned to an array row.
    d[i] = [int(field) for field in line.split()]

# Group the values by key, repeating each value ``count`` times so every
# list in ``requests`` is a flat sample of observations.
requests = {}
for count, key_id, value in d:
    requests.setdefault(str(key_id), []).extend([value] * int(count))

# Score every ordered pair of distinct keys; the score is not symmetric,
# so both (a, b) and (b, a) are computed.
similarities = []
for key in requests:
    for key2 in requests:
        if key != key2:
            similarities.append((key, key2, kl(requests[key], requests[key2])))

# One ``(key, key2, score)`` tuple per output line.
with open('result.txt', 'w') as f:
    f.writelines(str(ret) + '\n' for ret in similarities)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment