Skip to content

Instantly share code, notes, and snippets.

@hav4ik
Last active August 9, 2023 22:32
Show Gist options
  • Save hav4ik/100aa247eff4d3075db4f8314461f4c2 to your computer and use it in GitHub Desktop.
Save hav4ik/100aa247eff4d3075db4f8314461f4c2 to your computer and use it in GitHub Desktop.
Visualization scripts for the "At the Core of a Search Engine: Learning to Rank" blog post.
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# Apply seaborn's default theme to all subsequent matplotlib figures.
# set_theme() is the preferred spelling; sns.set() is only a legacy alias for it.
sns.set_theme()
def ndcg_at_k(ranked_relevances, k=None):
"""
Calculate NDCG for given ranked relevances.
"""
dcg = lambda rel, i: rel / np.log2(i + 2) # i + 2 because of Python's 0-indexing
k = k or len(ranked_relevances)
actual_dcg = sum(dcg(rel, i) for i, rel in enumerate(ranked_relevances[:k]))
ideal_dcg = sum(dcg(rel, i) for i, rel in enumerate(sorted(ranked_relevances, reverse=True)[:k]))
return actual_dcg / ideal_dcg
def display_ranking(ranking, ax, cut_off=5):
    """
    Display horizontal bars with given ranking relevances.

    Draws one bar per ranked position. The bar at y position ``i`` shows
    ``ranking[i]``. Each bar is labelled with its relevance and coloured from
    the reversed seaborn "crest" palette, indexed by relevance value. A dashed
    red line marks the cut-off. The axes title reports NDCG computed at that
    same cut-off.

    Args:
        ranking: Non-empty sequence of non-negative integer relevance labels,
            in ranked order.
        ax: Matplotlib Axes to draw on.
        cut_off: Number of top positions scored by NDCG and marked by the
            cut-off line.
    """
    ranking = np.asarray(ranking)
    n = len(ranking)
    # One colour per possible relevance value 0..max, so palette[rel] is valid
    # even when some levels are absent (e.g. [5, 0] or [3, 2, 1]).
    palette = np.array(sns.color_palette("crest", int(ranking.max()) + 1))[::-1]
    # Shift bar lengths so that relevance-0 items still get a visible bar.
    displayed_ranking = ranking + 2.5
    bars = sns.barplot(
        x=displayed_ranking,
        y=list(range(n)),
        orient="h",
        palette=palette[ranking],
        ax=ax,
    )
    # Label bars directly
    for idx, rect in enumerate(bars.patches):
        bars.text(
            rect.get_x() + rect.get_width() / 2.,
            idx,
            f'Relevance: {ranking[idx]}',
            ha='center',
            va='center',
            fontsize=10,
            color='white',
        )
    # Add cut-off line; the '-0.5' places it between bar cut_off-1 and bar cut_off
    ax.axhline(cut_off - 0.5, color='red', linestyle='--')
    ax.text(displayed_ranking.max(), cut_off - 0.7, "cut-off at T", va='center', color='red', ha='right')
    # Title label must use the same k as the NDCG computation it reports.
    ax.set_title(f'NDCG@{cut_off}: {ndcg_at_k(ranking, k=cut_off):.4f}')
    ax.get_yaxis().set_visible(False)  # Hide y-axis
    ax.get_xaxis().set_visible(False)  # Hide x-axis
# Example usage: plot the same set of relevance labels in ideal, random and
# worst (reversed-ideal) order to show how NDCG reacts to ordering.
# dtype=int rather than np.int: the np.int alias was removed in NumPy 1.24.
ideal_ranking = np.array([5, 4, 4, 3, 2, 1, 0, 0, 0, 0], dtype=int)
worst_ranking = ideal_ranking[::-1]
# Fixed seed so the "random" panel is reproducible between runs.
rng = np.random.default_rng(0)
random_ranking = rng.permutation(ideal_ranking)
rankings = [
    ideal_ranking,
    random_ranking,
    worst_ranking,
]
fig, axes = plt.subplots(1, len(rankings), figsize=(12, 5))  # One subplot per ranking, side by side
for ax, ranking in zip(axes, rankings):
    display_ranking(ranking, ax=ax)
plt.tight_layout()  # Adjust the spacing between plots
plt.show()
@hav4ik
Copy link
Author

hav4ik commented Aug 9, 2023

Everything is written with the help of ChatGPT :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment