Medium - FARM - Document Classification - Training and Inference
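
# Fine-tune a RoBERTa model for multi-label document classification with FARM,
# then run inference on a single article from the BBC news dataset.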
import torch
from farm.modeling.tokenization import Tokenizer
from farm.data_handler.processor import TextClassificationProcessor
from farm.data_handler.data_silo import DataSilo
from farm.modeling.language_model import Roberta
from farm.modeling.prediction_head import MultiLabelTextClassificationHead
from farm.modeling.adaptive_model import AdaptiveModel
from farm.modeling.optimization import initialize_optimizer
from farm.train import Trainer
from farm.utils import MLFlowLogger
from pathlib import Path
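
# Run everything on the CPU; switch to torch.device("cuda") if a GPU is available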
device = torch.device("cpu")
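
# Tokenizer matching the pretrained RoBERTa checkpoint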
tokenizer = Tokenizer.load(
    pretrained_model_name_or_path="roberta-base",
    do_lower_case=False)
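
# The five BBC news categories serve as the label set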
label_dirs = ['politics', 'entertainment', 'sport', 'business', 'tech']
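
# Processor converts the train/test TSV files into tokenized PyTorch datasets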
processor = TextClassificationProcessor(
    tokenizer=tokenizer,
    max_seq_len=256,
    data_dir=Path("./data_doc_class"),
    label_list=label_dirs,
    label_column_name="label",
    metric="acc",
    quote_char='"',
    multilabel=True,
    train_filename=Path("train.tsv"),
    test_filename=Path("test.tsv"),
    dev_split=0.1
)
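
# DataSilo loads the datasets and wraps them in train/dev/test DataLoaders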
data_silo = DataSilo(
    processor=processor,
    batch_size=32)
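
# Combine the pretrained RoBERTa encoder with a multi-label classification head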
language_model = Roberta.load("roberta-base")
prediction_head = MultiLabelTextClassificationHead(num_labels=len(label_dirs))
model = AdaptiveModel(
    language_model=language_model,
    prediction_heads=[prediction_head],
    embeds_dropout_prob=0.1,
    lm_output_types=["per_sequence"],
    device=device)
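
# Optimizer and learning-rate schedule for fine-tuning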
model, optimizer, lr_schedule = initialize_optimizer(
    model=model,
    learning_rate=3e-5,
    device=device,
    n_batches=len(data_silo.loaders["train"]),
    n_epochs=3)
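
# Trainer runs the training loop and evaluates on the dev set every 500 steps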
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    data_silo=data_silo,
    epochs=3,
    n_gpu=1,
    lr_schedule=lr_schedule,
    evaluate_every=500,
    device=device)
trainer.train()
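
# Save the model and processor together so the Inferencer can restore both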
save_dir = Path("./data_doc_class/saved_model")
model.save(save_dir)
processor.save(save_dir)
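
# Read a raw BBC business article for a test prediction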
with open("./bbc/bbc/business/001.txt") as f:
    test_string = f.read()
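
# Normalize the raw text with clean-text before running inference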
from cleantext import clean

test_string = clean(
    test_string,
    fix_unicode=True, to_ascii=True, no_line_breaks=True, no_urls=True,
    no_emails=True, no_phone_numbers=True, no_currency_symbols=True, no_punct=True,
    replace_with_punct="", replace_with_url="", replace_with_email="",
    replace_with_phone_number="", replace_with_number="", replace_with_digit="",
    replace_with_currency_symbol="", lang="en")
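
# Reload the saved model for inference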
from farm.infer import Inferencer
inferencer = Inferencer.load(save_dir)
print(inferencer.inference_from_dicts(dicts=[{"text": test_string}]))