Skip to content

Instantly share code, notes, and snippets.

@dxs1873
Created October 16, 2020 20:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dxs1873/de4b2320ae4ad96989153b669eeb2ee0 to your computer and use it in GitHub Desktop.
Save dxs1873/de4b2320ae4ad96989153b669eeb2ee0 to your computer and use it in GitHub Desktop.
XLNet Blurr Error
import pandas as pd
import torch
from transformers import *
from fastai.text.all import *
from blurr.data.all import *
from blurr.modeling.all import *
# --- Experiment configuration ------------------------------------------------
# Hugging Face checkpoint identifier; used below to load the config and to
# fetch the architecture/tokenizer/model objects.
pretrained_model_name = "xlnet-base-cased"
# Leading component of the checkpoint name ("xlnet-base-cased" -> "xlnet").
model_name = pretrained_model_name.split("-")[0]

# Every resource path is resolved relative to the directory the script is
# launched from. Path.cwd() yields the same path as Path(os.getcwd()) without
# depending on `os` being re-exported by one of the star imports.
root_dir = Path.cwd()
resource_dir = root_dir / "resources"
data_dir = resource_dir / "data"
# NOTE(review): model_dir is not referenced by the rest of the visible script.
model_dir = resource_dir / "models"
train_path = data_dir / "train.tsv"

# Name passed to `learn.export` at the end of the script ("XLNET_Classifier").
classifier_path = f"{model_name.upper()}_Classifier"
# --- Training data and Hugging Face objects ----------------------------------
# Tab-separated training table: the first row holds the column names and the
# first column is used as the row index.
train_data = pd.read_csv(
    train_path,
    sep="\t",
    header=0,
    index_col=0,
    skip_blank_lines=True,
)

# Sequence classification task for the blurr helper.
task = HF_TASKS_AUTO.SequenceClassification

# Start from the checkpoint's stock config and set the number of labels to
# the count of distinct values in the "target" column.
config = AutoConfig.from_pretrained(pretrained_model_name)
target_labels = train_data["target"].unique()
config.num_labels = len(target_labels)

# The helper returns four objects, unpacked here as architecture, config,
# tokenizer and model for the checkpoint, built with the customised config.
hf_objects = BLURR_MODEL_HELPER.get_hf_objects(
    pretrained_model_name,
    task=task,
    config=config,
)
hf_arch, hf_config, hf_tokenizer, hf_model = hf_objects
# --- Data pipeline -----------------------------------------------------------
# Inputs go through the blurr text block built from the Hugging Face
# architecture/tokenizer; targets are treated as categories.
blocks = (
    HF_TextBlock(hf_arch=hf_arch, hf_tokenizer=hf_tokenizer),
    CategoryBlock,
)

# Read inputs from the "text" column and labels from the "target" column;
# hold out 20% of the rows for validation with a fixed seed so the split is
# reproducible.
dblock = DataBlock(
    blocks=blocks,
    get_x=ColReader('text'),
    get_y=ColReader('target'),
    splitter=RandomSplitter(valid_pct=0.2, seed=321),
)

# Build the DataLoaders from the DataFrame with a batch size of 32.
dls = dblock.dataloaders(train_data, bs=32)
# --- Model, learner and training ---------------------------------------------
# Wrap the Hugging Face model in the blurr wrapper before giving it to fastai.
model = HF_BaseModelWrapper(hf_model)

learn = Learner(
    dls,
    model,
    path=resource_dir,  # base directory the learner uses for its files
    opt_func=partial(Adam, decouple_wd=True),
    loss_func=CrossEntropyLossFlat(),
    metrics=[accuracy],
    cbs=[HF_BaseModelCallback],
    splitter=hf_splitter,  # defines the parameter groups used by freeze()
)
# Switch the learner to mixed-precision training.
learn = learn.to_fp16()

# Create the optimizer explicitly, then freeze before training.
learn.create_opt()
learn.freeze()

# Ten epochs of one-cycle training with a peak learning rate of 1e-3.
learn.fit_one_cycle(n_epoch=10, lr_max=1e-3)

# Serialise the trained learner under the configured name.
learn.export(fname=classifier_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment