import torch
from flair.data import Sentence, Dictionary
from flair.data_fetcher import NLPTaskDataFetcher, NLPTask
from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharacterEmbeddings, \
    PooledFlairEmbeddings, FlairEmbeddings
from flair.visual.training_curves import Plotter
from flair.trainers import ModelTrainer
from flair.models import SequenceTagger
from flair.datasets import ColumnCorpus

def train():
    # 1. get the corpus: column-formatted data with the text in column 0 and the NER tag in column 1
    columns = {0: 'text', 1: 'ner'}
    data_folder = 'training/sequences'
    corpus = ColumnCorpus(data_folder, columns, in_memory=False)
    # corpus = NLPTaskDataFetcher.load_column_corpus(data_folder, columns)  # .downsample(0.05)
    # corpus = NLPTaskDataFetcher.load_corpus(NLPTask.WNUT_17).downsample(0.1)
    # for sentence in corpus.train:
    #     print(len(sentence))
    print(corpus)
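    # Note (illustrative, not from the original gist): ColumnCorpus reads CoNLL-style column
    # files (train/dev/test splits) from data_folder, one token per line with its tag in the
    # second column and a blank line between sentences, e.g.:
    #
    #   George      B-PER
    #   Washington  I-PER
    #   went        O
    #   to          O
    #   Washington  B-LOC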

    # 2. what tag do we want to predict?
    tag_type = 'ner'

    # 3. make the tag dictionary from the corpus
    tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
    print(tag_dictionary.idx2item)

    # 4. initialize embeddings
    embeddings = StackedEmbeddings(embeddings=[
        # classic GloVe word embeddings
        WordEmbeddings('glove'),
        # contextual string embeddings, forward and backward
        FlairEmbeddings('news-forward', use_cache=True, chars_per_chunk=64),
        FlairEmbeddings('news-backward', use_cache=True, chars_per_chunk=64),
        # custom word vectors loaded from a local file
        WordEmbeddings(embeddings="embeddings/markup2vec.wv"),
    ])

    # 5. initialize sequence tagger (a BiLSTM-CRF by default)
    tagger: SequenceTagger = SequenceTagger(hidden_size=128,
                                            embeddings=embeddings,
                                            tag_dictionary=tag_dictionary,
                                            tag_type=tag_type)

    # 6. initialize trainer
    trainer: ModelTrainer = ModelTrainer(tagger, corpus)

    # 7. start training
    trainer.train('resources/taggers/v003',
                  max_epochs=150,
                  mini_batch_size=64,
                  eval_mini_batch_size=32,
                  # embeddings_in_memory=False,
                  checkpoint=True)
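

# Illustrative sketch, not part of the original gist: how the trained tagger could be loaded
# and applied after training. It assumes training above produced
# 'resources/taggers/v003/final-model.pt' (recent Flair releases load from a path via
# SequenceTagger.load(); older ones used SequenceTagger.load_from_file()). The sample
# sentence is hypothetical.
def tag_example():
    tagger = SequenceTagger.load('resources/taggers/v003/final-model.pt')
    sentence = Sentence('George Washington went to Washington .')
    tagger.predict(sentence)
    print(sentence.to_tagged_string())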


if __name__ == '__main__':
    print("PyTorch Version: ", torch.__version__)
    print("CUDA available: ", torch.cuda.is_available())
    train()