# Source: GitHub Gist ("Instantly share code, notes, and snippets."),
# created July 21, 2021 13:43 (0 stars, 0 forks at capture time).
from pathlib import Path
from import (AutoModelForCausalLM, BLURR, CausalLMStrategy,
HF_LMBeforeBatchTransform, HF_CausalLMInput, HF_TextBlock, noop)
from fastai.text.all import mask2idxs, L
from import DataBlock
from import get_text_files, LMDataLoader
# Splitter for train and validatation
def _parent_idxs(items, name):
def _inner(items, name): return mask2idxs(Path(o) == name for o in items)
return [i for n in L(name) for i in _inner(items,n)]
def ParentSplitter(train_name='train', valid_name='valid'):
"Split `items` from the grand parent folder names (`train_name` and `valid_name`)."
def _inner(o, **kwargs):
return _parent_idxs(o, train_name),_parent_idxs(o, valid_name)
return _inner
# Config bits
# Sample configuration passed to `text_loader_from_blocks`. That function reads
# 'data_path', 'splitter' and 'max_seq_len'.
# NOTE(review): the remaining keys are not read by any code visible in this file;
# presumably they are consumed by later training code — confirm.
SAMPLE_CFG = {
    'data_path': '/media/HD/data/pt_wiki/wiki/pt-2',
    'bs': 4,
    'seed': 42,
    'is_lm': True,
    # Split by parent folder name ('train' / 'valid'); the splitter function
    # defined in this file is `ParentSplitter` (the name `ParrentSplitter` does not exist).
    'splitter': ParentSplitter(),
    'lang': 'pt',
    'fp_16': True,
    'drop_mult': 0.3,
    # Should match `pretrained_model_name` used to build the blurr objects.
    'model_name': 'pierreguillou/gpt2-small-portuguese',
    'max_seq_len': 72,
}
# Defining blurr objects
# Keep this in sync with the 'model_name' entry of the sample config.
pretrained_model_name = "pierreguillou/gpt2-small-portuguese"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(
    pretrained_model_name, model_cls=AutoModelForCausalLM
)
# Padding a batch requires a pad token; only assign one if the tokenizer has none.
# NOTE(review): '[PAD]' may not be in this checkpoint's vocabulary — confirm how
# the tokenizer maps it (reusing the EOS token is a common alternative).
if hf_tokenizer.pad_token is None:
    hf_tokenizer.pad_token = '[PAD]'
# The original call was truncated mid-argument list; complete it with the
# causal-LM strategy (imported above and otherwise unused).
before_batch_tfm = HF_LMBeforeBatchTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model, lm_strategy_cls=CausalLMStrategy
)
# Inputs are built by the blurr text block; targets come from the LM strategy,
# so the target block is a no-op.
blocks = [
    HF_TextBlock(before_batch_tfm=before_batch_tfm, input_return_type=HF_CausalLMInput),
    noop,
]
# Fastai like loaders
def text_loader_from_blocks(blocks, config, train="train", valid="valid"):
    """Build fastai DataLoaders (LMDataLoader) for text files under config["data_path"].

    Args:
        blocks: the DataBlock `blocks` (e.g. a blurr HF_TextBlock and `noop`).
        config: dict read for "data_path", "splitter", "max_seq_len" and,
            when present, "bs".
        train: top-level folder name scanned for training files.
        valid: top-level folder name scanned for validation files.
            NOTE(review): `train`/`valid` only restrict which folders
            `get_text_files` scans; the actual split comes from
            `config["splitter"]`, so keep the two consistent.

    Returns:
        The result of `DataBlock.dataloaders(...)` built with `dl_type=LMDataLoader`.
    """
    # `partial` is not among this file's imports; import it locally so it resolves.
    from functools import partial

    path = config["data_path"]
    get_items = partial(get_text_files, folders=[train, valid])
    dblock = DataBlock(blocks=blocks, get_items=get_items,
                       splitter=config["splitter"], dl_type=LMDataLoader)
    dl_kwargs = {"seq_len": config["max_seq_len"], "verbose": True}
    # Honour the configured batch size; previously "bs" was ignored and
    # fastai's default batch size was used instead.
    if "bs" in config:
        dl_kwargs["bs"] = config["bs"]
    return dblock.dataloaders(path, **dl_kwargs)
# Build the language-model DataLoaders from the blurr blocks and the sample config.
lm_loader = text_loader_from_blocks(blocks=blocks, config=SAMPLE_CFG)
# (Gist page footer: sign up for, or sign in to, GitHub to join the comment thread.)