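"""Minimal Ray Tune repro script.

Runs a trivial trainable (constant loss/accuracy, ten reports per trial)
under ASHAScheduler with a CLIReporter, using ray.init(local_mode=True)
for easier debugging. The commented-out SummaryWriter/threading lines
appear to be leftovers from earlier debugging and are kept as in the
original gist.
"""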
# from torch.utils.tensorboard import SummaryWriter
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
# writer = SummaryWriter()
# import threading
# l = threading.Lock()


def train_function(config):
    # Dummy trainable: reports a constant metric ten times, then returns.
    print("======================Entering train func===================================")
    # writer.close()
    # l.acquire()
    # l.release()
    for i in range(10):
        tune.report(loss=0, accuracy=1)
    print("======================Leaving train func===================================")


def main(num_samples=10, max_num_epochs=10):
    # Fixed config: "lr" is a constant here, not a search space.
    config = {
        "lr": 0.001,
    }
    # ASHA early-stopping scheduler, minimizing the reported "loss".
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    reporter = CLIReporter(
        # parameter_columns=["l1", "l2", "lr", "batch_size"],
        metric_columns=["loss", "accuracy", "training_iteration"])
    # Note: each trial requests a GPU even though ray.init() below is given
    # num_gpus=0; kept as in the original repro.
    result = tune.run(
        train_function,
        resources_per_trial={"cpu": 1, "gpu": 1},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter)


if __name__ == "__main__":
    import logging

    # local_mode=True runs Ray tasks serially in the driver process, which
    # makes the trials easier to step through and debug.
    ray.init(local_mode=True, num_cpus=1, num_gpus=0, logging_level=logging.DEBUG)
    main()
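
# --- Not in the original gist: a hedged sketch of driving the same script
# outside local mode. It assumes a machine with at least 2 CPUs and no GPU,
# so resources_per_trial would also need "gpu": 0 for trials to schedule.
#
#     import logging
#     ray.init(num_cpus=2, logging_level=logging.INFO)
#     main(num_samples=4)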