Skip to content

Instantly share code, notes, and snippets.

@mooreniemi
Created January 29, 2024 03:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mooreniemi/70e26db0d892e5fdd3d67e56a2f0a06f to your computer and use it in GitHub Desktop.
Save mooreniemi/70e26db0d892e5fdd3d67e56a2f0a06f to your computer and use it in GitHub Desktop.
Example of the IA-Select algorithm (intent-aware search result diversification)
use std::collections::HashMap;
use std::time::Instant;
/**
The IA-Select algorithm translated into Rust from
[Diversifying Search Results](https://www.microsoft.com/en-us/research/wp-content/uploads/2009/02/diversifying-wsdm09.pdf),
which the paper describes as a "(1 − 1/`e`)-approximation algorithm for `Diversify(k)`." `Diversify(k)` is the NP-hard
problem of choosing `k` documents that maximize the probability that the "average user" is satisfied by at least one of
them. IA stands for "intent aware"; intents are mapped to categories.

The selection is greedy: each round picks the remaining document with the highest marginal utility
`sum over categories of P(category | query, already selected) * quality(document, category)`, then scales every
category's conditional probability by `1 - quality(selected document, category)` so that categories which are already
well covered contribute less in later rounds.

# Arguments

* `num_documents_to_select` - upper bound (`k`) on the number of ids returned.
* `categories` - categories to sum over; the summation follows this order.
* `doc_ids` - candidate document ids. Ties in marginal utility are resolved in favor of the earlier entry. Once an id
  is selected, every occurrence of it is removed from the candidate pool.
* `probability_category_given_query` - `P(category | query)`; a missing category counts as `0.0`.
* `doc_id_category_score` - quality value per `(doc_id, category)`; a missing pair counts as `0.0`.

# Returns

The selected ids in selection order. The result is shorter than `num_documents_to_select` when the candidates run out,
or when every remaining candidate has a NaN marginal utility (such candidates cannot be ranked and are never selected).
*/
fn ia_select<'a>(
    num_documents_to_select: usize,
    categories: &Vec<&str>,
    doc_ids: &'a Vec<&'a str>,
    probability_category_given_query: &HashMap<&str, f64>,
    doc_id_category_score: &HashMap<(&str, &str), f64>,
) -> Vec<&'a str> {
    // Cap the capacity by the candidate count so a huge `k` cannot trigger a huge allocation.
    let mut selected_doc_ids = Vec::with_capacity(num_documents_to_select.min(doc_ids.len()));
    let mut remaining_doc_ids = doc_ids.clone();
    let mut conditional_probability = probability_category_given_query.clone();

    while selected_doc_ids.len() < num_documents_to_select && !remaining_doc_ids.is_empty() {
        // `None` means "nothing selectable seen yet". Using an Option instead of an empty-string / f64::MIN
        // sentinel keeps empty ids and very negative utilities selectable.
        let mut best: Option<(&'a str, f64)> = None;
        for &doc_id in &remaining_doc_ids {
            let marginal_utility: f64 = categories
                .iter()
                .map(|&category| {
                    conditional_probability.get(category).unwrap_or(&0.0)
                        * doc_id_category_score
                            .get(&(doc_id, category))
                            .unwrap_or(&0.0)
                })
                .sum();
            // NaN never compares greater than anything, so it cannot be ranked; skip it explicitly.
            if marginal_utility.is_nan() {
                continue;
            }
            // Strict `>` keeps the earliest document on ties.
            let improves = match best {
                Some((_, best_utility)) => marginal_utility > best_utility,
                None => true,
            };
            if improves {
                best = Some((doc_id, marginal_utility));
            }
        }

        // Without a selectable candidate no progress is possible; stop instead of looping forever.
        let candidate_doc_id = match best {
            Some((doc_id, _)) => doc_id,
            None => break,
        };

        // we selected a document, so we assume that we may have satisfied the query
        // given that, further documents have diminishing returns so we set conditional probability downwards
        selected_doc_ids.push(candidate_doc_id);
        for &category in categories {
            let quality = doc_id_category_score
                .get(&(candidate_doc_id, category))
                .unwrap_or(&0.0);
            let current_probability = conditional_probability.entry(category).or_insert(0.0);
            *current_probability *= 1.0 - quality;
        }
        remaining_doc_ids.retain(|&d| d != candidate_doc_id);
    }
    selected_doc_ids
}
fn main() {
// we usually call this k
let num_documents_to_select = 5;
// we often say verticals, or just different retrievers here, depending on how you look at it
// the paper calls them categories so that's what I use here
let categories = vec!["recent", "all"];
let doc_ids = vec![
"doc1", "doc2", "doc3", "doc4", "doc5", "doc6", "doc7", "doc8", "doc9", "doc10",
];
// this is the signal we depend on
let probability_category_given_query: HashMap<&str, f64> =
[("recent", 0.9), ("all", 0.1)].iter().cloned().collect();
// the ranking given by different rankers, these need to be normalized somehow
// otherwise a ranker with a higher overall distribution will dominate
// in the paper this is the "quality value" per category
let doc_id_category_score: HashMap<(&str, &str), f64> = [
// only in recent
(("doc1", "recent"), 0.5),
(("doc1", "all"), 0.0),
// in both but more relevant in all
(("doc2", "recent"), 0.2),
(("doc2", "all"), 0.4),
// only in recent, the highest recent
(("doc3", "recent"), 0.8),
(("doc3", "all"), 0.0),
// in both, but more relevant in all (different margin)
(("doc4", "recent"), 0.1),
(("doc4", "all"), 0.4),
// in both, equally
(("doc5", "recent"), 0.3),
(("doc5", "all"), 0.3),
// only in all, low relevance
(("doc6", "recent"), 0.0),
(("doc6", "all"), 0.3),
// only in all, very high scored
(("doc7", "recent"), 0.0),
(("doc7", "all"), 0.9),
// in both but more relevant in recent (doc2 inverse)
(("doc8", "recent"), 0.4),
(("doc8", "all"), 0.2),
// in both but low (expect doc5 > doc9)
(("doc9", "recent"), 0.1),
(("doc9", "all"), 0.1),
// doc10 missing but handled anyway
]
.iter()
.cloned()
.collect();
let now = Instant::now();
let selected_doc_ids = ia_select(
num_documents_to_select,
&categories,
&doc_ids,
&probability_category_given_query,
&doc_id_category_score,
);
println!("Microseconds elapsed: {}", now.elapsed().as_micros());
println!("Selected doc_ids: {:?} for {:?}", selected_doc_ids, probability_category_given_query);
let probability_category_given_query: HashMap<&str, f64> =
[("recent", 0.1), ("all", 0.9)].iter().cloned().collect();
let now = Instant::now();
let selected_doc_ids = ia_select(
num_documents_to_select,
&categories,
&doc_ids,
&probability_category_given_query,
&doc_id_category_score,
);
println!("Microseconds elapsed: {}", now.elapsed().as_micros());
println!("Selected doc_ids: {:?} for {:?}", selected_doc_ids, probability_category_given_query);
}
@mooreniemi
Copy link
Author

Output:

Microseconds elapsed: 95
Selected doc_ids: ["doc3", "doc8", "doc7", "doc1", "doc5"] for {"recent": 0.9, "all": 0.1}
Microseconds elapsed: 92
Selected doc_ids: ["doc7", "doc3", "doc2", "doc4", "doc5"] for {"all": 0.9, "recent": 0.1}

@mooreniemi
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment