@dchaplinsky · Created February 18, 2023 17:04
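This gist trains character-level Flair language models on the UberText corpus: a forward model and a backward model (the latter pinned to a second GPU). Both scripts expect a pre-built character dictionary at flair_dictionary.pkl. A minimal sketch of how such a dictionary can be produced, following Flair's documented recipe for custom character dictionaries; the corpus layout, paths, and frequency threshold below are assumptions, not part of the gist:

# Sketch (not in the gist): build a character dictionary compatible with
# Dictionary.load(). Assumes the corpus follows Flair's train/ split layout.
import pickle
from collections import Counter
from pathlib import Path

from flair.data import Dictionary

counter = Counter()
for path in Path("/data/ubertext/for_flair/train").iterdir():
    with open(path, encoding="utf-8") as f:
        for line in f:
            counter.update(line)  # count every character in the training split

char_dictionary = Dictionary()
for char, count in counter.most_common():
    if count < 100:  # drop very rare characters; the threshold is an assumption
        break
    char_dictionary.add_item(char)

# persist the mappings in the pickle format Flair's Dictionary loader expects
with open("flair_dictionary.pkl", "wb") as f:
    pickle.dump(
        {"idx2item": char_dictionary.idx2item, "item2idx": char_dictionary.item2idx}, f
    )

The forward model is trained first: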
import os.path

from flair.data import Dictionary
from flair.models import LanguageModel
from flair.trainers.language_model_trainer import LanguageModelTrainer, TextCorpus


def train_flair_embeddings(
    corpus_path="/data/ubertext/for_flair",
    dictionary_path="/home/dima/Projects/flair_embeddings/flair_dictionary.pkl",
    lm_file="./language_model_forward_no_amp_accum_grad_fixed",
    is_forward_lm=True,
    hidden_size=1024,
    sequence_length=250,
    mini_batch_size=256,
    max_epochs=20,
    gpus=1,  # not used in the function body
):
    # load the character dictionary built for the corpus
    dictionary: Dictionary = Dictionary.load(dictionary_path)

    # load the corpus, processed at the character level (direction set by is_forward_lm)
    corpus = TextCorpus(corpus_path, dictionary, is_forward_lm, character_level=True)

    # resume from a checkpoint if a previous run left one, otherwise start fresh
    checkpoint = os.path.join(lm_file, "checkpoint.pt")
    if os.path.exists(checkpoint):
        trainer = LanguageModelTrainer.load_checkpoint(checkpoint, corpus)
    else:
        language_model = LanguageModel(
            dictionary, is_forward_lm, hidden_size=hidden_size, nlayers=1
        )
        trainer = LanguageModelTrainer(language_model, corpus)

    trainer.train(
        lm_file,
        sequence_length=sequence_length,
        mini_batch_size=mini_batch_size,
        max_epochs=max_epochs,
        use_amp=False,
        checkpoint=True,
    )


train_flair_embeddings()
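The backward model uses the same routine, but pins Flair to the second GPU and is invoked with is_forward_lm=False and a larger mini-batch: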
import os.path

import flair
import torch
from flair.data import Dictionary
from flair.models import LanguageModel
from flair.trainers.language_model_trainer import LanguageModelTrainer, TextCorpus

# run the backward model on the second GPU
flair.device = torch.device("cuda:1")


def train_flair_embeddings(
    corpus_path="/data/ubertext/for_flair",
    dictionary_path="/home/dima/Projects/flair_embeddings/flair_dictionary.pkl",
    lm_file="./language_model_backward_no_amp_accum_grad_fixed",
    is_forward_lm=True,
    hidden_size=1024,
    sequence_length=250,
    mini_batch_size=200,
    max_epochs=25,
    gpus=1,  # not used in the function body
):
    # load the character dictionary built for the corpus
    dictionary: Dictionary = Dictionary.load(dictionary_path)

    # load the corpus, processed at the character level (direction set by is_forward_lm)
    corpus = TextCorpus(corpus_path, dictionary, is_forward_lm, character_level=True)

    # resume from a checkpoint if a previous run left one, otherwise start fresh
    checkpoint = os.path.join(lm_file, "checkpoint.pt")
    if os.path.exists(checkpoint):
        trainer = LanguageModelTrainer.load_checkpoint(checkpoint, corpus)
    else:
        language_model = LanguageModel(
            dictionary, is_forward_lm, hidden_size=hidden_size, nlayers=1
        )
        trainer = LanguageModelTrainer(language_model, corpus)

    trainer.train(
        lm_file,
        sequence_length=sequence_length,
        mini_batch_size=mini_batch_size,
        max_epochs=max_epochs,
        use_amp=False,
        checkpoint=True,
    )


train_flair_embeddings(is_forward_lm=False, mini_batch_size=480)
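Once both runs finish, the trained models can be loaded as Flair embeddings for downstream tasks. A minimal usage sketch; the best-lm.pt filenames assume Flair's default naming for the best model inside each output directory and are not stated in the gist:

# Sketch (not in the gist): load the trained forward/backward models as
# FlairEmbeddings and stack them into a single embedding.
from flair.data import Sentence
from flair.embeddings import FlairEmbeddings, StackedEmbeddings

embeddings = StackedEmbeddings([
    FlairEmbeddings("./language_model_forward_no_amp_accum_grad_fixed/best-lm.pt"),
    FlairEmbeddings("./language_model_backward_no_amp_accum_grad_fixed/best-lm.pt"),
])

# sample sentence; the corpus paths suggest Ukrainian (UberText) training data
sentence = Sentence("Це тестове речення.")
embeddings.embed(sentence)
for token in sentence:
    print(token.text, token.embedding.shape)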