Last active: October 25, 2019 03:41
Save lambday/96894241f039e8eb5c09338a1570d66a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Checks CrossValidationMMD against a hand-rolled reference for the
// ST_BIASED_FULL statistic. For every kernel k, run r and fold f, the
// rejection stored by CrossValidationMMD in m_rejections(r*num_folds+f, k)
// must equal (p_value < alpha), where p_value comes from a PermutationMMD
// run on the corresponding training split of the P and Q samples.
TEST(CrossValidationMMD, biased_full)
{
	const int32_t seed=5;
	const index_t n=24; // number of samples drawn from P
	const index_t m=15; // number of samples drawn from Q
	const index_t dim=2;
	const index_t num_null_samples=5;
	const index_t num_folds=3;
	const index_t num_runs=2;
	const index_t num_kernels=4;
	const index_t cache_size=10;
	const float64_t difference=0.5;
	const float64_t alpha=0.05;
	const auto stype=ST_BIASED_FULL;

	// Identically seeded engines: prng1 is consumed by CrossValidationMMD,
	// prng2 by the reference PermutationMMD below. The comparison relies on
	// both sides drawing random numbers in the same order.
	std::mt19937_64 prng1(seed);
	std::mt19937_64 prng2(seed);

	// Sample generators for P and Q; Q is constructed with `difference`
	// (presumably the mean shift between the two distributions — confirm).
	auto gen_p=some<CMeanShiftDataGenerator>(0, dim, 0);
	auto gen_q=some<CMeanShiftDataGenerator>(difference, dim, 0);
	gen_p->put("seed", seed);
	gen_q->put("seed", seed);
	auto feats_p=gen_p->get_streamed_features(n);
	auto feats_q=gen_q->get_streamed_features(m);
	auto merged_feats=static_cast<CDenseFeatures<float64_t>*>
		(feats_p->create_merged_copy(feats_q));

	// Candidate kernels: Gaussian kernels with widths 2^0 .. 2^(num_kernels-1).
	KernelManager kernel_mgr;
	for (index_t i=0; i<num_kernels; ++i)
	{
		auto width=pow(2, i);
		auto kernel=new CGaussianKernel(cache_size, width);
		kernel_mgr.push_back(kernel);
	}

	// Compute the pairwise distance matrix of the merged samples once, store it
	// (as float32) in a CCustomDistance and hand it to the kernel manager,
	// presumably so the kernels evaluated inside CrossValidationMMD reuse it.
	auto distance_instance=kernel_mgr.get_distance_instance();
	distance_instance->init(merged_feats, merged_feats);
	auto precomputed_distance=some<CCustomDistance>();
	auto distance_matrix=distance_instance->get_distance_matrix<float32_t>();
	precomputed_distance->set_triangle_distance_matrix_from_full(distance_matrix.data(), n+m, n+m);
	// NOTE(review): merged_feats is never unreferenced explicitly; its release
	// appears to rely on distance_instance dropping its lhs/rhs — confirm.
	SG_UNREF(distance_instance);
	kernel_mgr.set_precomputed_distance(precomputed_distance);

	// Cross-validation under test: one rejection per (run, fold) row and
	// per kernel column of m_rejections.
	auto cv=CrossValidationMMD(n, m, num_folds, num_null_samples, prng1);
	cv.m_stype=stype;
	cv.m_alpha=alpha;
	cv.m_num_runs=num_runs;
	cv.m_rejections=SGMatrix<float64_t>(num_runs*num_folds, num_kernels);
	cv(kernel_mgr, prng1);
	// The reference below evaluates kernels directly on the features.
	kernel_mgr.unset_precomputed_distance();

	// Reference K-fold splits over dummy labels sized n and m, seeded with the
	// same seed (presumably mirroring the splitting done inside cv — confirm).
	// NOTE(review): the label values are left uninitialized; this assumes only
	// their count is used by the splitting — confirm CBinaryLabels does not
	// read or validate them.
	SGVector<int64_t> dummy_labels_p(n);
	SGVector<int64_t> dummy_labels_q(m);
	auto kfold_p=some<CCrossValidationSplitting>(new CBinaryLabels(dummy_labels_p), num_folds);
	auto kfold_q=some<CCrossValidationSplitting>(new CBinaryLabels(dummy_labels_q), num_folds);
	kfold_p->put("seed", seed);
	kfold_q->put("seed", seed);

	auto permutation_mmd=PermutationMMD();
	permutation_mmd.m_stype=stype;
	permutation_mmd.m_num_null_samples=num_null_samples;

	for (index_t k=0; k<num_kernels; ++k)
	{
		CKernel* kernel=kernel_mgr.kernel_at(k);
		for (index_t current_run=0; current_run<num_runs; ++current_run)
		{
			// Rebuild the fold assignment at the start of every run.
			kfold_p->build_subsets();
			kfold_q->build_subsets();
			for (index_t current_fold=0; current_fold<num_folds; ++current_fold)
			{
				// Training subset = every index outside the current fold.
				auto current_train_subset_p=kfold_p->generate_subset_inverse(current_fold);
				auto current_train_subset_q=kfold_q->generate_subset_inverse(current_fold);
				feats_p->add_subset(current_train_subset_p);
				feats_q->add_subset(current_train_subset_q);
				permutation_mmd.m_n_x=feats_p->get_num_vectors();
				permutation_mmd.m_n_y=feats_q->get_num_vectors();
				auto current_merged_feats=static_cast<CDenseFeatures<float64_t>*>
					(feats_p->create_merged_copy(feats_q));
				kernel->init(current_merged_feats, current_merged_feats);
				auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix<float32_t>(), prng2);
				EXPECT_EQ(cv.m_rejections(current_run*num_folds+current_fold, k), p_value<alpha);
				// Detach the kernel's features and drop the subsets before the next fold.
				kernel->remove_lhs_and_rhs();
				feats_p->remove_subset();
				feats_q->remove_subset();
			}
		}
	}
	SG_UNREF(feats_p);
	SG_UNREF(feats_q);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.