@albertotb
Created July 15, 2016 15:02
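import sys

import numpy as np
from sklearn.model_selection import PredefinedSplit

if __name__ == "__main__":
    # Exactly three file arguments are required: TRAIN TEST VAL
    if len(sys.argv) != 4:
        print("usage: {0} TRAIN TEST VAL".format(sys.argv[0]))
        sys.exit(1)

    train = np.loadtxt(sys.argv[1])
    test = np.loadtxt(sys.argv[2])
    val = np.loadtxt(sys.argv[3])

    ntrain = train.shape[0]
    nval = val.shape[0]

    # Append the validation rows after the training rows
    train = np.concatenate((train, val))

    # -1: row is never used for validation; 0: row belongs to the single validation fold
    test_fold = -1 * np.ones(ntrain + nval)
    test_fold[ntrain:] = 0
    cv = PredefinedSplit(test_fold)
    # you can pass the object cv to functions such as GridSearchCV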
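The comment above says `cv` can be passed to GridSearchCV. The following is a minimal sketch of that, not part of the original gist. It makes several assumptions: random synthetic arrays stand in for the TRAIN and VAL files, the last column is treated as the target, and the `Ridge` estimator and its `alpha` grid are arbitrary choices for illustration. It also checks that `PredefinedSplit` produces exactly one split: the first `ntrain` rows for training and the last `nval` rows for validation.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, PredefinedSplit

# Synthetic stand-ins for the TRAIN and VAL files (assumption: last column is the target)
rng = np.random.RandomState(0)
ntrain, nval = 80, 20
train = rng.rand(ntrain, 4)
val = rng.rand(nval, 4)

# Same construction as the gist: validation rows appended after training rows
data = np.concatenate((train, val))
X, y = data[:, :-1], data[:, -1]

# -1 keeps a row in the training set; 0 puts it in the only validation fold
test_fold = -1 * np.ones(ntrain + nval)
test_fold[ntrain:] = 0
cv = PredefinedSplit(test_fold)

# PredefinedSplit yields one (train_indices, validation_indices) pair per fold value
splits = list(cv.split())
train_idx, val_idx = splits[0]

# Hyperparameters are scored only on the predefined validation rows
search = GridSearchCV(Ridge(), {"alpha": [0.01, 0.1, 1.0, 10.0]}, cv=cv)
search.fit(X, y)
print(cv.get_n_splits(), search.best_params_)
```

With the default `refit=True`, GridSearchCV retrains the best model on every row passed to `fit`, which here means training plus validation data. Pass `refit=False` if only the selected hyperparameters are needed.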