Skip to content

Instantly share code, notes, and snippets.

@Ab1992ao
Created June 2, 2021 13:43
Show Gist options
  • Save Ab1992ao/4af683278e911de43df90fc3dda185ad to your computer and use it in GitHub Desktop.
Save Ab1992ao/4af683278e911de43df90fc3dda185ad to your computer and use it in GitHub Desktop.
Training config for SBERT multi-task fine-tuning
class TrainingConfig:
    """Hyperparameter container for SBERT training with multi-task heads.

    ``__init__`` assigns a hard-coded default to every attribute, then
    applies each keyword argument with ``setattr``, overriding the default
    of the same name.  Keyword names are NOT validated: a misspelled key
    silently adds a new attribute instead of overriding the intended one.
    The resulting configuration is logged through ``log_this`` (a helper
    defined outside this class) at the end of construction.
    """

    def __init__(self, **kwargs) -> None:
        # --- naming / paths ---
        self.model_name = "sbert_tuned"
        self.data_dir = "/content/drive/MyDrive/"
        self.module_path = "/content/bert_module/"
        self.pretrained_ckpt = None
        self.generation = "sbert"

        # --- encoder ---
        self.ctx_len = 24  # presumably max tokens per input — confirm
        self.dim = 768
        self.n_tune = 12
        self.lr = 2e-6
        self.train_bert = True
        self.tune_embs = True

        # --- task heads, dropout and per-head loss weights ---
        self.head_dropout_rate = 0.1
        self.tagger_loss_weight = 0.1
        self.toxic_loss_weight = 0.1
        self.paraphrase_loss_weight = 0.2
        self.sample_dropout = 0.1
        self.use_par_head = True
        self.use_ner_head = True
        self.use_toxic_head = True

        # --- schedule / evaluation ---
        self.eval_on_start = False
        self.batch_size = 24
        self.eval_batch_size = 256
        self.n_epochs = 256
        self.n_tune_epochs = 25
        self.epoch_steps = 512
        self.valid_steps = 32
        self.metric_steps = 16

        # --- label spaces / optimisation ---
        self.n_toxic_tags = 2
        self.optimizer = 'adam'
        self.n_tags = 13
        self.type_loss = 'softmax'

        # Caller overrides win over the defaults above; unknown keys are
        # accepted as new attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.print_config()

    def to_dict(self) -> dict:
        """Return a shallow copy of the config's attributes.

        A copy is returned so that editing the dict (e.g. before handing it
        to a logger) cannot silently change or add attributes on the config.
        """
        return dict(self.__dict__)

    def to_string(self) -> str:
        """Serialise the config as concatenated ``_<name>_<value>`` segments.

        Segments follow attribute insertion order, e.g. attributes
        ``lr=2e-06`` and ``n_tags=13`` give ``"_lr_2e-06_n_tags_13"``.
        An empty config yields ``""``.
        """
        return "".join(f"_{key}_{value}" for key, value in self.to_dict().items())

    def print_config(self) -> None:
        """Log every attribute as an aligned ``name = value`` table.

        Names are right-padded with spaces to the longest name so the ``=``
        signs line up; the table is framed by 50-dash separator lines.
        """
        conf = self.to_dict()
        # default=0 keeps this from raising on an attribute-less instance.
        width = max((len(key) for key in conf), default=0)
        log_this(50 * '-')
        log_this("Using config:")
        for key, value in conf.items():
            log_this(f"{key:<{width}} = {value}")
        log_this(50 * '-')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment