Skip to content

Instantly share code, notes, and snippets.

@lambday
Last active October 25, 2019 03:41
Show Gist options
  • Save lambday/96894241f039e8eb5c09338a1570d66a to your computer and use it in GitHub Desktop.
Save lambday/96894241f039e8eb5c09338a1570d66a to your computer and use it in GitHub Desktop.
/**
 * Checks that CrossValidationMMD's rejection matrix for ST_BIASED_FULL agrees
 * with a reference computation done by hand: the same folds are rebuilt with
 * CCrossValidationSplitting and each (kernel, run, fold) is tested with
 * PermutationMMD.
 *
 * Data: n samples from a CMeanShiftDataGenerator with shift 0 (p) and m
 * samples with shift `difference` (q), both of dimension `dim`.
 * Kernels: num_kernels Gaussian kernels with widths 2^0 .. 2^(num_kernels-1).
 *
 * Ownership: the feature copies returned by create_merged_copy() are
 * SG_REF'd explicitly and SG_UNREF'd once the test is done with them, so
 * their lifetime does not depend on distance/kernel init() holding the only
 * reference.
 */
TEST(CrossValidationMMD, biased_full)
{
	const int32_t seed=5;
	const index_t n=24;
	const index_t m=15;
	const index_t dim=2;
	const index_t num_null_samples=5;
	const index_t num_folds=3;
	const index_t num_runs=2;
	const index_t num_kernels=4;
	const index_t cache_size=10;
	const float64_t difference=0.5;
	const float64_t alpha=0.05;
	const auto stype=ST_BIASED_FULL;

	// Identically seeded engines: prng1 drives cv, prng2 drives the reference
	// PermutationMMD. NOTE(review): agreement presumably relies on both
	// consuming the same random sequence -- confirm.
	std::mt19937_64 prng1(seed);
	std::mt19937_64 prng2(seed);

	auto gen_p=some<CMeanShiftDataGenerator>(0, dim, 0);
	auto gen_q=some<CMeanShiftDataGenerator>(difference, dim, 0);
	gen_p->put("seed", seed);
	gen_q->put("seed", seed);

	auto feats_p=gen_p->get_streamed_features(n);
	auto feats_q=gen_q->get_streamed_features(m);

	// create_merged_copy() returns a fresh object; hold our own reference
	// instead of relying on distance_instance to keep it alive.
	auto merged_feats=static_cast<CDenseFeatures<float64_t>*>
		(feats_p->create_merged_copy(feats_q));
	SG_REF(merged_feats);

	KernelManager kernel_mgr;
	for (auto i=0; i<num_kernels; ++i)
	{
		// width = 2^i, computed exactly (no floating-point pow()).
		const float64_t width=static_cast<float64_t>(1<<i);
		kernel_mgr.push_back(new CGaussianKernel(cache_size, width));
	}

	// Compute the pairwise distances on the full merged sample once and hand
	// them to kernel_mgr via set_precomputed_distance().
	auto distance_instance=kernel_mgr.get_distance_instance();
	distance_instance->init(merged_feats, merged_feats);
	auto precomputed_distance=some<CCustomDistance>();
	auto distance_matrix=distance_instance->get_distance_matrix<float32_t>();
	precomputed_distance->set_triangle_distance_matrix_from_full(distance_matrix.data(), n+m, n+m);
	SG_UNREF(distance_instance);
	kernel_mgr.set_precomputed_distance(precomputed_distance);

	auto cv=CrossValidationMMD(n, m, num_folds, num_null_samples, prng1);
	cv.m_stype=stype;
	cv.m_alpha=alpha;
	cv.m_num_runs=num_runs;
	// One row per (run, fold), one column per kernel.
	cv.m_rejections=SGMatrix<float64_t>(num_runs*num_folds, num_kernels);
	cv(kernel_mgr, prng1);
	kernel_mgr.unset_precomputed_distance();

	// Reference computation: split p and q with identically seeded k-fold
	// splitters and run a permutation test per (kernel, run, fold).
	SGVector<int64_t> dummy_labels_p(n);
	SGVector<int64_t> dummy_labels_q(m);
	auto kfold_p=some<CCrossValidationSplitting>(new CBinaryLabels(dummy_labels_p), num_folds);
	auto kfold_q=some<CCrossValidationSplitting>(new CBinaryLabels(dummy_labels_q), num_folds);
	kfold_p->put("seed", seed);
	kfold_q->put("seed", seed);

	auto permutation_mmd=PermutationMMD();
	permutation_mmd.m_stype=stype;
	permutation_mmd.m_num_null_samples=num_null_samples;

	for (auto k=0; k<num_kernels; ++k)
	{
		CKernel* kernel=kernel_mgr.kernel_at(k);
		for (auto current_run=0; current_run<num_runs; ++current_run)
		{
			// Splits are rebuilt at the start of every run.
			kfold_p->build_subsets();
			kfold_q->build_subsets();
			for (auto current_fold=0; current_fold<num_folds; ++current_fold)
			{
				auto current_train_subset_p=kfold_p->generate_subset_inverse(current_fold);
				auto current_train_subset_q=kfold_q->generate_subset_inverse(current_fold);
				feats_p->add_subset(current_train_subset_p);
				feats_q->add_subset(current_train_subset_q);
				permutation_mmd.m_n_x=feats_p->get_num_vectors();
				permutation_mmd.m_n_y=feats_q->get_num_vectors();

				// Owned by this iteration; released after the kernel drops it.
				auto current_merged_feats=static_cast<CDenseFeatures<float64_t>*>
					(feats_p->create_merged_copy(feats_q));
				SG_REF(current_merged_feats);

				kernel->init(current_merged_feats, current_merged_feats);
				auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix<float32_t>(), prng2);
				// Row layout of m_rejections: run-major, row = run*num_folds + fold.
				EXPECT_EQ(cv.m_rejections(current_run*num_folds+current_fold, k), p_value<alpha);

				kernel->remove_lhs_and_rhs();
				SG_UNREF(current_merged_feats);
				feats_p->remove_subset();
				feats_q->remove_subset();
			}
		}
	}

	SG_UNREF(feats_p);
	SG_UNREF(feats_q);
	SG_UNREF(merged_feats);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment