Skip to content

Instantly share code, notes, and snippets.

@TkrUdagawa
Created March 11, 2016 10:16
Show Gist options
  • Save TkrUdagawa/7a6e2ba33a450962379a to your computer and use it in GitHub Desktop.
Save TkrUdagawa/7a6e2ba33a450962379a to your computer and use it in GitHub Desktop.
diff --git jubatus/core/bandit/bandit_factory_test.cpp jubatus/core/bandit/bandit_factory_test.cpp
index 559b5be..a1117c6 100644
--- jubatus/core/bandit/bandit_factory_test.cpp
+++ jubatus/core/bandit/bandit_factory_test.cpp
@@ -48,6 +48,7 @@ TEST(bandit_factory, ucb1) {
TEST(bandit_factory, ts) {
json::json js(new json::json_object);
+ js["assume_unrewarded"] = json::to_json(true);
common::jsonconfig::config conf(js);
shared_ptr<bandit_base> p = bandit_factory::create("ts", conf);
EXPECT_EQ("ts", p->name());
diff --git jubatus/core/bandit/ts.cpp jubatus/core/bandit/ts.cpp
index b5c48d4..c9ce863 100644
--- jubatus/core/bandit/ts.cpp
+++ jubatus/core/bandit/ts.cpp
@@ -53,8 +53,8 @@ std::string ts::select_arm(const std::string& player_id) {
result = arms[i];
}
}
+ s_.notify_selected(player_id, result);
return result;
- //return arms[select_by_weights(weights, rand_)];
}
bool ts::register_arm(const std::string& arm_id) {
@@ -64,12 +64,14 @@ bool ts::delete_arm(const std::string& arm_id) {
return s_.delete_arm(arm_id);
}
-bool ts::register_reward(const std::string& player_id,
- const std::string& arm_id,
- double reward) {
- if( (reward != 0.0) && (reward != 1.0) ){
+bool ts::register_reward(
+ const std::string& player_id,
+ const std::string& arm_id,
+ double reward) {
+ // Thompson sampling assumes binary rewards
+ if ((reward != 0.0) && (reward != 1.0)) {
throw JUBATUS_EXCEPTION(
- common::exception::runtime_error("reward is not in {0,1}")); //Thompson sampling assumes binary rewards
+ common::exception::runtime_error("reward is not in {0,1}"));
}
const std::vector<std::string>& arms = s_.get_arm_ids();
size_t i = std::find(arms.begin(), arms.end(), arm_id) - arms.begin();
diff --git jubatus/util/math/random.h jubatus/util/math/random.h
index 5c54975..0f139c7 100644
--- jubatus/util/math/random.h
+++ jubatus/util/math/random.h
@@ -134,7 +134,7 @@ public:
}else{
x = -log((1.0-u1)/(c*alpha));
if( log(u2) <= (alpha-1)*log(x) ){
- return true;
+ return x;
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.