# Created September 12, 2022 12:29
# Save tchaton/0346f7af32d182afdd1ef7231ce5d1ac to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently
# than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
"""Lightning HPO sweep that optimizes a simple CNN over MNIST with Optuna.

Exposes a module-level ``app`` (a ``LightningApp``) wrapping a ``Sweep`` that
runs ``scripts/train.py`` (resolved relative to this file) over the search
space defined below.
"""
import os.path as ops

from lightning import LightningApp
from lightning_hpo import HPOCloudCompute, Sweep
from lightning_hpo.algorithm.optuna import OptunaAlgorithm
from lightning_hpo.distributions.distributions import Categorical, IntUniform, LogUniform, Uniform

app = LightningApp(
    Sweep(
        # Pass path components separately instead of embedding "/" in one
        # string, so the joined path uses the native separator on every OS.
        script_path=ops.join(ops.dirname(__file__), "scripts", "train.py"),
        n_trials=5,
        simultaneous_trials=2,
        # Search space. The dotted keys presumably map onto the training
        # script's config / CLI arguments -- confirm against scripts/train.py.
        distributions={
            "model.lr": LogUniform(0.001, 0.1),
            "model.gamma": Uniform(0.5, 0.8),
            "data.batch_size": Categorical([16, 32, 64]),
            "trainer.max_epochs": IntUniform(1, 5),
        },
        # NOTE(review): "maximize" assumes the objective reported by
        # train.py is higher-is-better (e.g. accuracy) -- confirm.
        algorithm=OptunaAlgorithm(direction="maximize"),
        cloud_compute=HPOCloudCompute("gpu-fast-multi", count=2),  # 2 * 4 V100
        framework="pytorch_lightning",
        logger="wandb",
        sweep_id="Optimizing a Simple CNN over MNIST with Lightning HPO",
    )
)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.