Skip to content

Instantly share code, notes, and snippets.

@suma
Last active August 29, 2015 14:02
Show Gist options
  • Save suma/2dfac4356d8bd1bb14ac to your computer and use it in GitHub Desktop.
Save suma/2dfac4356d8bd1bb14ac to your computer and use it in GitHub Desktop.
linear_mixer test for nearest_neighbor (jubatus_core).
diff --git a/jubatus/core/driver/nearest_neighbor_test.cpp b/jubatus/core/driver/nearest_neighbor_test.cpp
index 554ed51..00f00bf 100644
--- a/jubatus/core/driver/nearest_neighbor_test.cpp
+++ b/jubatus/core/driver/nearest_neighbor_test.cpp
@@ -145,10 +145,14 @@ class nearest_neighbor_test
: public ::testing::TestWithParam<
shared_ptr<core::nearest_neighbor::nearest_neighbor_base> > {
protected:
- void SetUp() {
- nn_driver_ = shared_ptr<core::driver::nearest_neighbor>(
+ shared_ptr<core::driver::nearest_neighbor> create_driver() const {  // Factory: builds a new driver from GetParam() and make_fv_converter().
+ return shared_ptr<core::driver::nearest_neighbor>(
new core::driver::nearest_neighbor(GetParam(), make_fv_converter()));
}
+
+ void SetUp() {  // Fixture setup: stores a driver built by create_driver() in nn_driver_.
+ nn_driver_ = create_driver();
+ }
void TearDown() {
nn_driver_->clear();
}
@@ -300,6 +304,47 @@ TEST_P(nearest_neighbor_test, small) {
nn_driver_->neighbor_row_from_data(create_datum_2d(1.f, 1.f), 2);
}
+TEST_P(nearest_neighbor_test, small_mix) {  // Runs get_diff -> convert_diff_object -> mix -> put_diff between two drivers.
+ framework::linear_mixable* nn_mixable =
+ dynamic_cast<framework::linear_mixable*>(nn_driver_->get_mixable());  // Null if the mixable is not a linear_mixable.
+ shared_ptr<driver::nearest_neighbor> other = create_driver();  // Second driver; no rows are set on it in this test.
+ framework::linear_mixable* other_mixable =
+ dynamic_cast<framework::linear_mixable*>(other->get_mixable());
+ ASSERT_TRUE(nn_mixable);  // Abort the test early if either cast failed.
+ ASSERT_TRUE(other_mixable);
+
+ nn_driver_->set_row("a", single_str_datum("x", "hoge"));  // Rows are added to the fixture driver only.
+ nn_driver_->set_row("b", single_str_datum("y", "fuga"));
+
+ msgpack::sbuffer data;  // Receives the packed diff of the fixture driver.
+ {
+ core::framework::stream_writer<msgpack::sbuffer> st(data);
+ core::framework::jubatus_packer jp(st);
+ core::framework::packer pk(jp);
+ nn_mixable->get_diff(pk);  // Pack nn_mixable's diff into `data`.
+ }
+ {
+ msgpack::sbuffer sbuf;  // Receives the packed diff of `other`.
+ core::framework::stream_writer<msgpack::sbuffer> st(sbuf);
+ core::framework::jubatus_packer jp(st);
+ core::framework::packer pk(jp);
+ other_mixable->get_diff(pk);
+
+ msgpack::unpacked msg;
+ msgpack::unpack(&msg, sbuf.data(), sbuf.size());  // Decode other's packed diff back into a msgpack object.
+ std::cout << msg.get() << std::endl;  // NOTE(review): looks like leftover debug output - confirm before merging.
+ framework::diff_object diff = other_mixable->convert_diff_object(msg.get());  // `diff` is the accumulator passed to mix() below.
+ std::cout << "hello " << msg.get() << std::endl;  // NOTE(review): looks like leftover debug output - confirm before merging.
+
+ msgpack::unpacked data_msg;
+ msgpack::unpack(&data_msg, data.data(), data.size());  // Decode the fixture driver's packed diff.
+
+ other_mixable->mix(data_msg.get(), diff);  // Mix the fixture driver's diff into `diff`.
+ other_mixable->put_diff(diff);  // NOTE(review): no assertions follow; the test only checks that these calls complete.
+ }
+}
+
+
INSTANTIATE_TEST_CASE_P(nearest_neighbor_test_instance,
nearest_neighbor_test,
testing::ValuesIn(create_nearest_neighbor_bases()));
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment