# Gist by @tgalery, created July 21, 2021 13:43
from functools import partial
from pathlib import Path

from blurr.data.language_modeling import (AutoModelForCausalLM, BLURR, CausalLMStrategy,
                                          HF_LMBeforeBatchTransform, HF_CausalLMInput,
                                          HF_TextBlock, noop)
from fastai.text.all import mask2idxs, L
from fastai.data.block import DataBlock
from fastai.text.data import get_text_files, LMDataLoader
# Splitter for train and validation sets, based on each file's parent folder name
def _parent_idxs(items, name):
    def _inner(items, name): return mask2idxs(Path(o).parent.name == name for o in items)
    return [i for n in L(name) for i in _inner(items, n)]

def ParentSplitter(train_name='train', valid_name='valid'):
    "Split `items` from the parent folder names (`train_name` and `valid_name`)."
    def _inner(o, **kwargs):
        return _parent_idxs(o, train_name), _parent_idxs(o, valid_name)
    return _inner
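
# Mini self-check (illustrative, not part of the original gist): ParentSplitter
# groups indices by the name of each file's parent folder.
_demo_items = ["data/train/a.txt", "data/train/b.txt", "data/valid/c.txt"]
print(ParentSplitter()(_demo_items))  # -> ([0, 1], [2]): train idxs, valid idxs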
# Config bits
SAMPLE_CFG = {
    'data_path': '/media/HD/data/pt_wiki/wiki/pt-2',
    'bs': 4,
    'seed': 42,
    'is_lm': True,
    'splitter': ParentSplitter(),
    'lang': 'pt',
    'fp_16': True,
    'drop_mult': 0.3,
    'model_name': 'pierreguillou/gpt2-small-portuguese',
    'max_seq_len': 72
}
# Defining blurr objects
pretrained_model_name = "pierreguillou/gpt2-small-portuguese"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, model_cls=AutoModelForCausalLM)
if (hf_tokenizer.pad_token is None): hf_tokenizer.pad_token = '[PAD]'
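# Note (hedged, not in the original gist): setting pad_token to a plain string does
# not add '[PAD]' to the GPT-2 vocabulary, so the pad id may silently fall back to
# the unk/eos token. A common alternative is to register it and resize embeddings:
#   hf_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
#   hf_model.resize_token_embeddings(len(hf_tokenizer))
# or simply use hf_tokenizer.pad_token = hf_tokenizer.eos_token for GPT-2 models.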
before_batch_tfm = HF_LMBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                             lm_strategy_cls=CausalLMStrategy)
blocks = [HF_TextBlock(before_batch_tfm=before_batch_tfm, input_return_type=HF_CausalLMInput), noop]
# fastai-style DataLoaders built from the blurr blocks
def text_loader_from_blocks(blocks, config, train="train", valid="valid"):
    path = config["data_path"]
    get_items = partial(get_text_files, folders=[train, valid])
    dblock = DataBlock(blocks=blocks, get_items=get_items, splitter=config["splitter"], dl_type=LMDataLoader)
    return dblock.dataloaders(path, bs=config["bs"], seq_len=config["max_seq_len"], verbose=True)
lm_loader = text_loader_from_blocks(blocks, SAMPLE_CFG)
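
# Quick sanity check on the resulting DataLoaders (standard fastai API; a sketch,
# not part of the original gist). A Learner would normally be built next with
# blurr's model wrapper from blurr.modeling, but those class names vary between
# blurr releases, so that step is left out here.
lm_loader.show_batch(max_n=2)  # decoded text samples for a couple of items
b = lm_loader.one_batch()      # one tokenised batch (structure depends on the blurr input type)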