import math
import torch
import torch.nn as nn
import transformers
import random
import os
import numpy as np
import time
import pandas as pd
from transformers import AutoTokenizer, AutoModelForTokenClassification
#import ray
from ray import tune
seed = 300596
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def hyper_optim(config, model, tokenizer, logging=False, tuning=False, checkpoint_dir=None, seed=seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # example model
    #model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
    # `device` is a module-level global set further down, before tune.run is called.
    model.to(device)
    best_loss = 100
    best_accuracy = 0
    best_f1 = 0
    early_stopping_count = 0
    epoch = 0
    if checkpoint_dir:
        print("Loading from checkpoint.")
        path = os.path.join(checkpoint_dir, "checkpoint")
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint["model_state_dict"])
        early_stopping_count = checkpoint["early_stopping_count"]
        epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]
    while True:
        if not tuning and early_stopping_count >= 7:
            break
        epoch += 1
        early_stopping_count += 1
        # this is where the model would be trained and validated; a random value stands in for the real validation loss
        val_loss = random.random()
        # update best loss
        if val_loss < best_loss:
            best_loss = val_loss
            early_stopping_count = 0
        if tuning:
            # checkpoint our current state.
            with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                # Save state to checkpoint file. The file is roughly 400 MB here; for the
                # real model it is around 1.2 GB. However, 400 MB already triggers the
                # PLACEMENT_GROUP_REMOVED error. See output below.
                torch.save({
                    "epoch": epoch,
                    "early_stopping_count": early_stopping_count,
                    "model_state_dict": model.state_dict(),
                    "best_loss": best_loss
                }, path)
        if tuning:
            tune.report(val_loss=val_loss, epoch=epoch,
                        early_stopping_count=early_stopping_count, best_loss=best_loss)
configuration = {
    'model_name': 'Transformer',
    'num_labels': 3,
    'batch_size': tune.choice([8, 16, 32]),
    'lr': tune.loguniform(1e-5, 1e-1),
    'warmup': tune.uniform(0, 0.1),
    'w_decay': tune.uniform(0, 0.3),
    'n_epochs': 30,
    'max_length': 512
}
import gc
torch.cuda.empty_cache()
gc.collect()
# If there's a GPU available...
if torch.cuda.is_available():
    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
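# Not part of the original gist: hyper_optim can also be exercised outside of Tune
# by passing tuning=False, in which case the loop exits once early_stopping_count
# reaches 7. The fixed_config values below are illustrative only; uncomment to try.
# fixed_config = {'model_name': 'Transformer', 'num_labels': 3, 'batch_size': 16,
#                 'lr': 3e-5, 'warmup': 0.06, 'w_decay': 0.1,
#                 'n_epochs': 30, 'max_length': 512}
# hyper_optim(fixed_config, model, tokenizer, logging=False, tuning=False)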
bohb_hyperband = HyperBandForBOHB(
    time_attr="training_iteration",
    max_t=configuration["n_epochs"],
    reduction_factor=3,
    stop_last_trials=True)
bohb_search = TuneBOHB(seed=300596)
bohb_search = tune.suggest.ConcurrencyLimiter(bohb_search, max_concurrent=4)
def stopper(trial_id, result):
    # End a trial once 7 consecutive epochs pass without improving the best loss.
    return result["early_stopping_count"] >= 7
import ray
ray.init(address="auto")
from ray.tune.syncer import SyncConfig
analysis = tune.run(
    tune.with_parameters(hyper_optim, model=model, tokenizer=tokenizer, logging=False, tuning=True),
    name="bohb_test",
    scheduler=bohb_hyperband,
    metric="val_loss",
    mode="min",
    verbose=3,
    search_alg=bohb_search,
    stop=stopper,
    sync_config=SyncConfig(sync_on_checkpoint=False),
    keep_checkpoints_num=1,
    checkpoint_score_attr="training_iteration",
    resources_per_trial={"cpu": 6},
    num_samples=4,
    reuse_actors=True,
    config=configuration)
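# Not part of the original gist: a minimal sketch of inspecting the results once the
# run finishes, using the ExperimentAnalysis object returned by tune.run (metric/mode
# were passed above, so best_config and best_trial are defined).
print("Best config:", analysis.best_config)
print("Best trial final val_loss:", analysis.best_trial.last_result["val_loss"])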