Skip to content

Instantly share code, notes, and snippets.

@joshreini1
Last active October 27, 2022 21:43
Show Gist options
  • Star (0) — you must be signed in to star a gist.
  • Fork (0) — you must be signed in to fork a gist.
  • Save joshreini1/9f30ea53ab6bb54b559c1ea94b42750b to your computer and use it in GitHub Desktop.
# Copy every split of the "data_collection" data collection into
# "data_collection_v2" with standardized labels, and register an XGBoost
# regressor trained on the standardized training labels.
#
# Labels are standardized with the mean/std of the TRAINING split only, and
# that single transform is applied to every split. Standardizing each split
# with its own statistics leaks test/OOT label information into the transform
# and puts each split's labels on a different scale from the one the model
# was fit on, which makes evaluation on those splits misleading.
#
# Assumes `tru` (the workspace object) and `xgb` (xgboost) are already in
# scope. Also assumes get_ys() returns a pandas-like object supporting
# .mean()/.std() and element-wise arithmetic — TODO confirm.

SOURCE_COLLECTION = "data_collection"
TARGET_COLLECTION = "data_collection_v2"
# Source split used both to fit the label transform and to train the model.
TRAIN_SPLIT = "sf_train_post"

tru.set_data_collection(SOURCE_COLLECTION)
splits = tru.get_data_splits()
# Presumably `splits` is a collection of split names (the loop below
# compares its items to strings) — verify against the workspace API.
if TRAIN_SPLIT not in splits:
    raise ValueError(
        f"training split {TRAIN_SPLIT!r} not found in {SOURCE_COLLECTION!r}; "
        "cannot compute label standardization statistics"
    )

# Fit the label transform on the training split only.
tru.set_data_split(TRAIN_SPLIT)
train_ys = tru.get_ys()
ys_mean = train_ys.mean()
ys_std = train_ys.std()

for split in splits:
    # Read from the source collection, then switch to the target collection
    # before writing the transformed split there.
    tru.set_data_collection(SOURCE_COLLECTION)
    tru.set_data_split(split)
    xs = tru.get_xs()
    ys = tru.get_ys()
    tru.set_data_collection(TARGET_COLLECTION)

    # Same (training-derived) transform for every split.
    ys_std_from_mean = (ys - ys_mean) / ys_std

    if split == TRAIN_SPLIT:
        tru.add_data_split(
            "sf_train",
            pre_data=xs,
            label_data=ys_std_from_mean,
            split_type="train",
        )
        xgb_reg = xgb.XGBRegressor()
        xgb_reg.fit(xs, ys_std_from_mean)
        tru.add_python_model("xgb_v2", xgb_reg)
    elif split == "sf_test":
        tru.add_data_split(
            "sf_test",
            pre_data=xs,
            label_data=ys_std_from_mean,
            split_type="test",
        )
    else:
        # All remaining splits keep their own name and are registered with
        # split_type "oot" (presumably out-of-time data).
        tru.add_data_split(
            split,
            pre_data=xs,
            label_data=ys_std_from_mean,
            split_type="oot",
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment