Skip to content

Instantly share code, notes, and snippets.

@kumpera
Created October 8, 2020 20:29
Show Gist options
  • Save kumpera/713d8920d712b87b214610a428e0bd06 to your computer and use it in GitHub Desktop.
diff --git a/vowpalwabbit/conditional_contextual_bandit.cc b/vowpalwabbit/conditional_contextual_bandit.cc
index c9041e7e3..29200a31a 100644
--- a/vowpalwabbit/conditional_contextual_bandit.cc
+++ b/vowpalwabbit/conditional_contextual_bandit.cc
@@ -420,70 +420,74 @@ void learn_or_predict(ccb& data, multi_learner& base, multi_ex& examples)
auto decision_scores = examples[0]->pred.decision_scores;
- // for each slot, re-build the cb example and call cb_explore_adf
- size_t slot_id = 0;
- for (example* slot : data.slots)
- {
- // Namespace crossing for slot features.
- data.generated_interactions.clear();
- std::copy(data.original_interactions->begin(), data.original_interactions->end(),
- std::back_inserter(data.generated_interactions));
- calculate_and_insert_interactions(data.shared, data.actions, data.generated_interactions);
- data.shared->interactions = &data.generated_interactions;
- for (auto* ex : data.actions)
+ try {
+ // for each slot, re-build the cb example and call cb_explore_adf
+ size_t slot_id = 0;
+ for (example* slot : data.slots)
{
- ex->interactions = &data.generated_interactions;
- }
+ // Namespace crossing for slot features.
+ data.generated_interactions.clear();
+ std::copy(data.original_interactions->begin(), data.original_interactions->end(),
+ std::back_inserter(data.generated_interactions));
+ calculate_and_insert_interactions(data.shared, data.actions, data.generated_interactions);
+ data.shared->interactions = &data.generated_interactions;
+ for (auto* ex : data.actions)
+ {
+ ex->interactions = &data.generated_interactions;
+ }
- data.include_list.clear();
- build_cb_example<is_learn>(data.cb_ex, slot, data);
+ data.include_list.clear();
+ build_cb_example<is_learn>(data.cb_ex, slot, data);
- if (data.all->audit)
- {
- inject_slot_id<true>(data, data.shared, slot_id);
- }
- else
- {
- inject_slot_id<false>(data, data.shared, slot_id);
- }
+ if (data.all->audit)
+ {
+ inject_slot_id<true>(data, data.shared, slot_id);
+ }
+ else
+ {
+ inject_slot_id<false>(data, data.shared, slot_id);
+ }
- if (has_action(data.cb_ex))
- {
- // the cb example contains at least 1 action
- multiline_learn_or_predict<is_learn>(base, data.cb_ex, examples[0]->ft_offset);
- save_action_scores(data, decision_scores);
- clear_pred_and_label(data);
- }
- else
- {
- // the cb example contains no action => cannot decide
- decision_scores.push_back(data.action_score_pool.get_object());
- }
+ if (has_action(data.cb_ex))
+ {
+ // the cb example contains at least 1 action
+ multiline_learn_or_predict<is_learn>(base, data.cb_ex, examples[0]->ft_offset);
+ save_action_scores(data, decision_scores);
+ clear_pred_and_label(data);
+ }
+ else
+ {
+ // the cb example contains no action => cannot decide
+ decision_scores.push_back(data.action_score_pool.get_object());
+ }
- data.shared->interactions = data.original_interactions;
- for (auto* ex : data.actions)
- {
- ex->interactions = data.original_interactions;
- }
- remove_slot_features(data.shared, slot);
+ data.shared->interactions = data.original_interactions;
+ for (auto* ex : data.actions)
+ {
+ ex->interactions = data.original_interactions;
+ }
+ remove_slot_features(data.shared, slot);
- if (data.all->audit)
- {
- remove_slot_id<true>(data.shared);
- }
- else
- {
- remove_slot_id<false>(data.shared);
+ if (data.all->audit)
+ {
+ remove_slot_id<true>(data.shared);
+ }
+ else
+ {
+ remove_slot_id<false>(data.shared);
+ }
+
+ // Put back the original shared example tag.
+ std::swap(data.shared->tag, slot->tag);
+ slot_id++;
+ data.cb_ex.clear();
}
- // Put back the original shared example tag.
- std::swap(data.shared->tag, slot->tag);
- slot_id++;
- data.cb_ex.clear();
+ delete_cb_labels(data);
+ } catch (...) {
+ printf("OH YOU FAILED CCB!\n");
}
- delete_cb_labels(data);
-
// Restore ccb labels to the example objects.
for (size_t i = 0; i < examples.size(); i++)
{
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment