Last active: May 1, 2022 15:25
-
-
Save krsnewwave/7945a4e574d4ea8316c02ad1770864e1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pip install "ray[tune]" comet_ml
from functools import partial

from ray import tune
from ray.tune.integration.comet import CometLoggerCallback
from ray.tune.integration.pytorch_lightning import TuneReportCallback
def train_function(model_conf, novelty_per_item, epochs, patience,
                   train_loader, val_loader, checkpoint_dir=None):
    """Train a single CDAE model as one Ray Tune trial.

    Builds a ``CDAE`` model, fits it with a PyTorch Lightning ``Trainer``, and
    reports validation metrics back to Ray Tune at the end of every
    validation run.

    Args:
        model_conf: Configuration passed as the first argument to ``CDAE``.
            Under ``tune.run`` this is presumably the trial's sampled
            hyperparameter dict — confirm against the caller.
        novelty_per_item: Passed straight through to ``CDAE``.
        epochs: Maximum number of training epochs (``max_epochs``).
        patience: Early-stopping patience, monitored on ``val_loss_epoch``
            (lower is better).
        train_loader: Dataloader used for training.
        val_loader: Dataloader used for validation.
        checkpoint_dir: Accepted to match Ray Tune's function-trainable
            signature; it is not used anywhere in this function.
    """
    # NOTE(review): num_users, num_items, CDAE and pl are not defined in this
    # function; they are expected to come from module scope elsewhere.
    cdae_model = CDAE(model_conf, novelty_per_item, num_users, num_items)

    # Metric names Tune receives; "val_Prec@20" must match the metric that
    # tune.run optimizes.
    reported_metrics = ["val_loss_epoch", "val_Prec@20"]

    trainer = pl.Trainer(
        accelerator="auto",
        callbacks=[
            pl.callbacks.EarlyStopping("val_loss_epoch", mode='min', patience=patience),
            # Bridges Lightning's logged metrics to Ray Tune.
            TuneReportCallback(reported_metrics, on="validation_end"),
        ],
        max_epochs=epochs,
        enable_progress_bar=False,
        log_every_n_steps=1,
    )
    trainer.fit(
        model=cdae_model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )
# Hyperparameter grid for CDAE. Every key uses tune.grid_search, so Tune runs
# the full Cartesian product: 3 * 3 * 2 * 3 * 3 * 3 = 486 trials.
search_space_conf = {
    "hidden_dim": tune.grid_search([50, 100, 200]),
    "corruption_ratio": tune.grid_search([0.3, 0.5, 0.8]),
    "activation": tune.grid_search(['sigmoid', 'tanh']),
    "negative_sample_prob": tune.grid_search([0, 0.5, 1]),
    "learning_rate": tune.grid_search([0.1, 0.05, 0.01]),
    "wd": tune.grid_search([0, 0.01, 0.001]),
}

# Tune callback that logs each trial to Comet. API_KEY, PROJECT_NAME and
# WORKSPACE are expected to be defined elsewhere.
comet_logger = CometLoggerCallback(
    api_key=API_KEY,
    project_name=PROJECT_NAME,
    workspace=WORKSPACE,
    tags=["cdae_tuning"],
)

# Bind the fixed arguments BY KEYWORD so that ``model_conf`` is the only free
# positional parameter left. Tune calls the trainable with the trial's sampled
# config as its first positional argument, so that config now lands in
# ``model_conf`` and actually reaches CDAE. Binding ``model_conf``
# positionally (as before) left ``checkpoint_dir`` as the next positional
# slot. The sampled config was bound there instead, every trial trained the
# same fixed ``model_conf``, and the grid search had no effect.
# NOTE(review): if the base ``model_conf`` carries keys that are not in the
# search space, merge them into the Tune config, e.g.
# config={**model_conf, **search_space_conf} — confirm model_conf is a dict.
train_function_instance = partial(
    train_function,
    novelty_per_item=novelty_per_item,
    epochs=epochs,
    patience=patience,
    train_loader=train_loader,
    val_loader=val_loader,
)

analysis = tune.run(
    train_function_instance,
    name='cdae',
    # Must match a metric name reported by TuneReportCallback in train_function.
    metric="val_Prec@20",
    mode='max',
    config=search_space_conf,
    callbacks=[comet_logger],
    # To cap total tuning wall-clock time, pass e.g. time_budget_s=200.
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.