Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
Last active May 1, 2022 15:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krsnewwave/7945a4e574d4ea8316c02ad1770864e1 to your computer and use it in GitHub Desktop.
Save krsnewwave/7945a4e574d4ea8316c02ad1770864e1 to your computer and use it in GitHub Desktop.
# Requires Ray Tune, Comet ML and PyTorch Lightning:
#   pip install "ray[tune]" comet_ml pytorch-lightning
from ray.tune.integration.comet import CometLoggerCallback
from functools import partial
from ray.tune.integration.pytorch_lightning import TuneReportCallback
def train_function(model_conf, novelty_per_item, epochs, patience,
                   train_loader, val_loader, checkpoint_dir=None):
    """Fit one CDAE model and report its validation metrics to Ray Tune.

    ``num_users`` and ``num_items`` are not parameters; they are read from
    the enclosing (module) scope when the model is constructed.

    Args:
        model_conf: Model configuration handed straight to ``CDAE``.
        novelty_per_item: Forwarded unchanged to ``CDAE``.
        epochs: Upper bound on training epochs (``max_epochs``).
        patience: Early-stopping patience on ``val_loss_epoch``.
        train_loader: Dataloader used for training.
        val_loader: Dataloader used for validation.
        checkpoint_dir: Accepted but not used inside this function.

    Returns:
        None. Results are surfaced through the Ray Tune report callback.
    """
    # Metrics forwarded to Ray Tune; fill in any others you need here.
    reported_metrics = ["val_loss_epoch", "val_Prec@20"]

    # Stop a trial once the validation loss has stopped decreasing.
    early_stopping = pl.callbacks.EarlyStopping(
        "val_loss_epoch", mode='min', patience=patience)
    # Bridge between Lightning and Ray Tune: pushes the metrics above to
    # Tune each time a validation run ends.
    tune_reporter = TuneReportCallback(reported_metrics, on="validation_end")

    cdae = CDAE(model_conf, novelty_per_item, num_users, num_items)
    trainer = pl.Trainer(
        accelerator="auto",
        callbacks=[early_stopping, tune_reporter],
        max_epochs=epochs,
        enable_progress_bar=False,
        log_every_n_steps=1,
    )
    trainer.fit(
        model=cdae,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )
# Hyperparameter search space: every entry is a grid_search axis, so the
# sweep covers 3 * 3 * 2 * 3 * 3 * 3 = 486 combinations when Tune crosses
# all axes (see the Ray Tune grid_search documentation).
search_space_conf = {
    name: tune.grid_search(candidates)
    for name, candidates in (
        ("hidden_dim", [50, 100, 200]),
        ("corruption_ratio", [0.3, 0.5, 0.8]),
        ("activation", ['sigmoid', 'tanh']),
        ("negative_sample_prob", [0, 0.5, 1]),
        ("learning_rate", [0.1, 0.05, 0.01]),
        ("wd", [0, 0.01, 0.001]),
    )
}
# Ray Tune callback for Comet, passed to tune.run via ``callbacks``.
# API_KEY, PROJECT_NAME and WORKSPACE must already be defined; every
# experiment created by this callback is tagged "cdae_tuning".
comet_logger = CometLoggerCallback(
    tags=["cdae_tuning"],
    workspace=WORKSPACE,
    project_name=PROJECT_NAME,
    api_key=API_KEY,
)
def train_function_instance(config, checkpoint_dir=None):
    """Adapt ``train_function`` to the call signature Ray Tune uses.

    Ray Tune calls a function trainable with the sampled hyperparameters as
    the first positional argument. Binding all six leading parameters of
    ``train_function`` with ``functools.partial`` pushed that config into the
    unused ``checkpoint_dir`` slot, so every trial trained the same
    ``model_conf``. This adapter overlays the sampled values on the baseline
    configuration before training.

    Args:
        config: Hyperparameters sampled by Tune from ``search_space_conf``.
        checkpoint_dir: Checkpoint directory supplied by Tune, forwarded to
            ``train_function``.
    """
    # NOTE(review): assumes model_conf is a dict of baseline settings whose
    # keys match the search-space names read by CDAE — confirm.
    trial_conf = {**model_conf, **config}
    train_function(trial_conf, novelty_per_item, epochs, patience,
                   train_loader, val_loader, checkpoint_dir=checkpoint_dir)


# Launch the sweep; the best trial is the one maximising val_Prec@20.
analysis = tune.run(
    train_function_instance,
    name='cdae',
    metric="val_Prec@20",
    mode='max',
    config=search_space_conf,
    callbacks=[comet_logger],
    # Optional: pass time_budget_s=200 to cap the sweep's wall-clock seconds.
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment