Created February 8, 2013 16:13
-
-
Save gsk3/4740012 to your computer and use it in GitHub Desktop.
Solution and benchmarking for the Stack Overflow question http://stackoverflow.com/questions/14684539/sample-with-a-max/14684701#14684701
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Setup ----
# inline::cxxfunction() compiles the C++ sampler defined below.
library(inline)
set.seed(123)
# Test input: 200 draws (with replacement) from 1..100.
x <- sample(1:100, 200, replace = TRUE)
# C++ body for inline::cxxfunction() with plugin "Rcpp".
# Arguments: x (integer values to sample from), n (number of elements wanted),
# s (target sum). Returns an integer vector of length maxAttempts (100) in
# which NA marks never-used slots and 0 marks picks that were dropped.
# The string is single-quoted in R, so the C++ text must not contain
# apostrophes, and every backslash must be doubled.
cpp_src <- '
  Rcpp::RNGScope rngScope;            // draw from the R RNG so set.seed() applies
  Rcpp::IntegerVector xa = clone(x);  // Vector to be sampled
  Rcpp::IntegerVector na(n);          // Number of elements in solution
  Rcpp::IntegerVector sa(s);          // Sum of solution
  if (xa.size() == 0) Rcpp::stop("x must contain at least one element");
  int nsampled;
  int currentSum;
  int dropRandomIndex;
  int numZeroes;
  Rcpp::IntegerVector remainingQuantity(1);
  int maxAttempts = 100;
  // Create container for our results; NA means the slot was never filled
  Rcpp::IntegerVector res(maxAttempts);
  std::fill(res.begin(), res.end(), NA_INTEGER);
  // Largest value in x: upper bound for the final cheating element
  Rcpp::IntegerVector::iterator mx = std::max_element(xa.begin(), xa.end());
  Rcpp::Rcout << "mx = " << *mx << std::endl;
  // Now draw repeatedly
  nsampled = 0;
  for (int i = 0; i < maxAttempts; i++) {
    Rcpp::Rcout << "\\n" << i;
    // Pick a uniformly random POSITION of xa (not a value used as an index,
    // which could run past the end of xa). The clamp guards against rounding.
    int r = (int) R::runif(0.0, (double) xa.size());
    if (r >= xa.size()) r = (int) xa.size() - 1;
    res[i] = xa[r];
    // Count dropped (zeroed) slots so they are excluded from nsampled
    numZeroes = 0;
    for (int j = 0; j < maxAttempts; j++)
      if (res[j] == 0) numZeroes++;
    Rcpp::Rcout << " nz= " << numZeroes;
    // Filled, non-dropped slots; the trailing - 1 leaves out res[i]
    nsampled = maxAttempts - sum(is_na(res)) - numZeroes - 1;
    // Sum of res[0 .. i-1], skipping NA slots (present after a restart);
    // adding NA_INTEGER directly would corrupt the total
    currentSum = 0;
    for (int j = 0; j < i; j++)
      if (res[j] != NA_INTEGER) currentSum += res[j];
    Rcpp::Rcout << " nsamp= " << nsampled << " sum= " << currentSum;
    if (nsampled == na[0] - 1) {
      Rcpp::Rcout << " One element away. ";
      remainingQuantity[0] = sa[0] - currentSum;
      Rcpp::Rcout << "remainingQuantity = " << remainingQuantity[0];
      if ((remainingQuantity[0] > 0) && (remainingQuantity[0] < *mx)) {
        Rcpp::Rcout << "Within range. Prepare the secret (cheating) weapon!\\n";
        Rcpp::Rcout << sa[0] << " ";
        Rcpp::Rcout << currentSum << " ";
        Rcpp::Rcout << remainingQuantity[0] << std::endl;
        // i < maxAttempts always holds inside the loop, so slot i is valid;
        // the exact shortfall replaces the draw just stored there
        Rcpp::Rcout << "Safe to add one last element on the end. Doing so.\\n";
        res[i] = remainingQuantity[0];
        currentSum = sa[0];
        nsampled++;
        if (nsampled == na[0] && currentSum == sa[0])
          Rcpp::Rcout << "It should end after this...nsamp= " << nsampled
                      << " and currentSum= " << currentSum << std::endl;
        break;
      } else {
        Rcpp::Rcout << "Out of striking distance. Dropping random element\\n";
        dropRandomIndex = (int) R::runif(0.0, i + 1.0);  // in [0, i]
        res[dropRandomIndex] = 0;
      }
    }
    if (nsampled == na[0] && currentSum == sa[0]) {
      Rcpp::Rcout << "Success!\\n";
      // res[i] is the last filled slot; res[i + 1] may be past the end
      for (int l = 0; l <= i; l++)
        Rcpp::Rcout << res[l] << " ";
      break;
    }
    if (nsampled == na[0] && currentSum != sa[0]) {
      Rcpp::Rcout << "Reached number of elements but sum is ";
      if (currentSum > sa[0]) {
        Rcpp::Rcout << "Too high. Blitz everything and start over!\\n";
        std::fill(res.begin(), res.end(), NA_INTEGER);
      } else {
        Rcpp::Rcout << "Too low. \\n";
      }
    }
    if (nsampled < na[0] && currentSum >= sa[0]) {
      Rcpp::Rcout << "Too few elements but at or above the sum cutoff. Dropping a random element and trying again.\\n";
      dropRandomIndex = (int) R::runif(0.0, i + 1.0);  // in [0, i]
      res[dropRandomIndex] = 0;
    }
  }
  return res;
'
# Compile the C++ body into sumto(x, n, s) and try it once.
sumto <- cxxfunction(
  signature(x = "integer", n = "integer", s = "integer"),
  body = cpp_src,
  plugin = "Rcpp",
  verbose = TRUE
)
testresult <- sumto(x = x, n = 20L, s = 1000L)
# The result uses NA for unused slots and 0 for dropped picks; keep the rest.
testresult <- testresult[!is.na(testresult) & testresult != 0L]
testresult
cumsum(testresult)
length(testresult)
# Test other functions ----
# Shared problem for every approach: n elements from vec summing to target.
n <- 20L
target <- 1000L
vec <- seq_len(100)
set.seed(123)
# ABF's R function (an R repeat loop)
# Grow a random sample from `vec` one draw at a time, repairing it until it
# holds exactly `n` elements whose sum is `target`.
#
# vec:    values that may be drawn.
# n:      required number of elements.
# target: required sum.
# Returns the sampled vector. Prints a progress line on every iteration.
sumto_repeat <- function(vec, n, target) {
  # vec[sample.int(...)] rather than sample(vec, 1): a length-one numeric vec
  # would otherwise be expanded to 1:vec. The RNG use is identical otherwise.
  draw_one <- function() vec[sample.int(length(vec), 1L)]
  res <- integer()
  repeat {
    cat("begin:", sum(res), length(res), "\n")
    res <- c(res, draw_one())
    remaining <- target - sum(res)
    # One element short and under target: finish with the exact shortfall,
    # but only when that value could itself have been drawn from vec.
    if (remaining > 0 && length(res) == n - 1 && remaining %in% vec) {
      res[length(res) + 1] <- remaining
    }
    # Overshot the target: undo the latest element.
    if (sum(res) > target) {
      res <- res[-length(res)]
    }
    # Too many elements, or the target was hit too early: drop one at random.
    if (length(res) > n || (length(res) < n && sum(res) == target)) {
      res <- res[-sample.int(length(res), 1L)]
    }
    if (sum(res) == target && length(res) == n) break
  }
  res
}
# Ananda's R function.
# Rejection-sample VecLen integers whose sum is within Tolerance of Target and
# whose values are all positive and inside range(InRange).
#
# Target:    desired sum.
# VecLen:    number of elements to return.
# InRange:   only its min() and max() are used, as inclusive bounds.
# Tolerance: allowed absolute deviation of the sum from Target.
# showSum:   if TRUE, print the achieved total.
# Returns a numeric vector of length VecLen.
SampleToSum <- function(Target = 100, VecLen = 10,
                        InRange = 1:100, Tolerance = 2,
                        showSum = TRUE) {
  lo <- min(InRange)
  hi <- max(InRange)
  repeat {
    # Cut [0, 1] at VecLen - 1 uniform points. The piece lengths sum to 1,
    # so scaling by Target and rounding gives values summing to about Target.
    Res <- round(diff(c(0, sort(runif(VecLen - 1)), 1)) * Target)
    if (all(Res > 0) && all(Res >= lo) && all(Res <= hi) &&
        abs(sum(Res) - Target) <= Tolerance) {
      break
    }
  }
  if (isTRUE(showSum)) cat("Total = ", sum(Res), "\n")
  Res
}
# Greg Snow's solution: start each of n bins with one unit, then drop the other
# target - n units into bins uniformly at random. Every count is >= 1, there
# are exactly n counts, and they sum to target.
# The earlier transcription sampled from `vec` and added `target` extra units.
# That gave the wrong length and a sum of target + n (1020 for n = 20,
# target = 1000).
# NOTE: this ignores any per-element maximum; `vec` is accepted only so the
# signature matches the other benchmarked functions.
gs <- function(vec, n, target) {
  stopifnot(target >= n)
  tabulate(sample.int(n, target - n, replace = TRUE), nbins = n) + 1L
}
# Benchmark ----
# Every candidate solves the same problem: n elements summing to target.
library(microbenchmark)
m <- microbenchmark(
  SampleToSum(Target = target, VecLen = n, InRange = vec, Tolerance = 0),
  sumto(vec, n, target),
  sumto_repeat(vec, n, target),
  gs(vec, n, target),
  times = 10
)
m
library(ggplot2)
library(taRifx)
autoplot(m)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
You spent too much time on this question ;)
I don't totally understand Greg Snow's solution. It just doesn't seem to work for me. The `length` is all wrong, and the range of the results is very narrow---I think it answers a different question. Also, to get the correct output value, do remember that `vec` should be `vec - n` (which is why the sum is 20 off in your test).