

@henry0312
Last active June 24, 2023 07:57
k-fold cross validation with Optuna
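import numpy as np
import optuna
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold

# Assumption: toy data and a Ridge model stand in for your own dataset/model.
X, y = make_regression(n_samples=200, n_features=10, noise=0.1, random_state=0)

class LocalObjective(object):
    def __init__(self, params, train_idx, valid_idx):
        self.params = dict(params)
        self.train_idx, self.valid_idx = train_idx, valid_idx

    def __call__(self, trial):
        # fit on this fold's training rows, score on its held-out rows
        model = Ridge(**self.params).fit(X[self.train_idx], y[self.train_idx])
        return mean_squared_error(y[self.valid_idx], model.predict(X[self.valid_idx]))

class GlobalObjective(object):
    def __init__(self, k):
        self.k = k
        self.folds = list(KFold(n_splits=k, shuffle=True, random_state=0).split(X))
        # one local study per fold records that fold's loss for every global trial
        self.local_studies = [optuna.create_study() for _ in range(k)]

    def __call__(self, trial):
        # sample params once (illustrative alpha range); every fold reuses them
        params = {'alpha': trial.suggest_float('alpha', 1e-3, 10.0, log=True)}
        losses = []
        for study, (train_idx, valid_idx) in zip(self.local_studies, self.folds):
            study.optimize(LocalObjective(params, train_idx, valid_idx), n_trials=1)
            losses.append(study.trials[-1].value)  # store this fold's loss
        # calculate the cross-validated loss as the mean over folds
        return float(np.mean(losses))

def main():
    study = optuna.create_study()
    global_obj = GlobalObjective(k=5)
    study.optimize(global_obj, n_trials=100)
    print(study.best_params, study.best_value)

if __name__ == '__main__':
    main()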

stoyky commented May 25, 2023

Thank you for this code, very interesting repurpose of Optuna mechanisms!


joshuasv commented Jun 24, 2023

Did any of you get this error?

OSError: [Errno 24] Too many open files

Edit (fixed)

If you use torch.utils.data.DataLoader, do not set persistent_workers=True; it causes the error above (or a related one: for instance, with an optuna.storages.JournalStorage you may get OSError: [Errno 24] Too many open files: './example.journal').
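Below is a minimal sketch of that fix. It assumes a PyTorch pipeline that builds new loaders for every fold of every trial, and `make_fold_loaders` is a hypothetical helper. `persistent_workers` is left at its default of `False`, so each loader's worker processes shut down once an epoch's iterator is exhausted. They do not stay alive for as long as the loader object does.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def make_fold_loaders(dataset, train_idx, valid_idx, batch_size=32, num_workers=2):
    """Hypothetical helper: fresh train/valid loaders for one fold.

    persistent_workers=False (the PyTorch default) lets worker processes,
    and the file descriptors they hold, be released at the end of each epoch.
    """
    common = dict(batch_size=batch_size, num_workers=num_workers,
                  persistent_workers=False)
    train_loader = DataLoader(Subset(dataset, train_idx), shuffle=True, **common)
    valid_loader = DataLoader(Subset(dataset, valid_idx), shuffle=False, **common)
    return train_loader, valid_loader


# Usage with toy tensors; num_workers=0 keeps this demo single-process.
ds = TensorDataset(torch.randn(100, 4), torch.randn(100))
train_loader, valid_loader = make_fold_loaders(
    ds, list(range(80)), list(range(80, 100)), num_workers=0)
print(sum(len(xb) for xb, _ in train_loader))  # 80
```

With `num_workers > 0` on platforms that spawn worker processes (Windows, macOS), put the calling code under an `if __name__ == '__main__':` guard.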
