import sys

# star-import of project utilities and shared third-party imports (torch, numpy,
# transformers, accelerate, wandb, hydra, ConfigStore, tqdm, etc. are expected
# to come from here)
from src import *
from functools import partial  # used for tqdm below (re-imported again later)
from dataclasses import dataclass
from typing import Optional
from omegaconf import OmegaConf
@dataclass
class CFG:
    per_device_eval_batch_size: int = 2
    per_device_train_batch_size: int = 2
    weight_decay: float = 0.0
    learning_rate: float = 5e-5
    gradient_accumulation_steps: int = 2
    num_train_epochs: int = 1
    max_train_steps: Optional[int] = None
    lr_scheduler_type: str = "linear"
    num_warmup_steps: int = 1000
    fast_run_pct: Optional[float] = None  # if set, keep only this fraction of each split
    seed: int = 42
    out_dir: str = "output_dir"
    shuffle: bool = True
    h: bool = False  # "help": print the resolved config and exit
    num_eval: int = 3  # number of evaluations spread over training
    start_year: int = 2010
    end_year: int = 2017
    model_name: str = "gpt2"
    pretrained: bool = True
    time_mode: str = "v2"  # one of: v2, f, prefix, prefix2, suffix, none (see dispatch below)
cs = ConfigStore.instance()
cs.store(name="config", node=CFG)

# notebook defaults; when run as a script, Hydra replaces these via main() below
args = CFG()
args.time_mode = "suffix"


@hydra.main(config_name="config")
def main(cfg: CFG) -> None:
    global args
    args = cfg


if __name__ == "__main__":
    # IS_SCRIPT is presumably defined in src (True when run as a script rather
    # than interactively)
    if IS_SCRIPT:
        main()

print(OmegaConf.to_yaml(args))
if args.h:
    sys.exit()
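# Example invocation (hypothetical file name; Hydra overrides map onto CFG fields):
#   python train.py time_mode=prefix learning_rate=3e-5 num_train_epochs=2
# With h=true the script only prints the resolved config (above) and exits.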
set_seed(args.seed)
accelerator = Accelerator(fp16=True)
# only show progress bars on the local main process
tqdm = partial(tqdm, disable=not accelerator.is_local_main_process)

# post-process args
is_local_main_process = accelerator.is_local_main_process
total_batch_size = (
    args.per_device_train_batch_size
    * accelerator.num_processes
    * args.gradient_accumulation_steps
)
os.makedirs(args.out_dir, exist_ok=True)
# thin wandb wrapper that only logs from the local main process
class Logger:
    def __init__(self, *args, **kwargs):
        if is_local_main_process:
            self.run = wandb.init(*args, **kwargs)

    def log(self, dic):
        if is_local_main_process:
            wandb.log(dic)

    def close(self):
        if is_local_main_process:
            wandb.finish()
import datatable as dt
from datasets import load_from_disk, load_dataset
from datasets import Dataset

dataset = load_from_disk("../input/processed/arxiv_processed_v2.json")
dataset = dataset["train"]
df = dt.fread("../input/processed/arxiv_id_date_split.csv").to_pandas()

dss = {}
for s in ["train", "dev", "test"]:
    ids = df.query(
        f'split == "{s}" and year >= {args.start_year} and year <= {args.end_year}'
    ).index
    # dataset['time'] = time
    # print(ids)
    dss[s] = dataset.select(ids)

if args.fast_run_pct is not None:
    pct = 1 - args.fast_run_pct
    dss = {
        s: ds.select(np.arange(int(len(ds) * pct), len(ds))) for s, ds in dss.items()
    }
dss
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
tokenizer.pad_token = tokenizer.eos_token

from transformers import PreTrainedTokenizerBase


@dataclass
class DataCollatorForCLM:
    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: int = 16

    def __call__(self, batch):
        time_col = "time_ratio"
        has_time = time_col in batch[0]
        if has_time:
            time = torch.tensor([x[time_col] for x in batch])
        batch = self.tokenizer(
            [x["abstract"] for x in batch],
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        labels = batch["input_ids"].clone()
        # padded positions are masked to -100 so the loss ignores them; no
        # pad_token guard here, so a missing pad token fails loudly
        # if self.tokenizer.pad_token_id is not None:
        labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        if has_time:
            batch["time"] = time
        return batch
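# Sanity-check sketch (illustrative only; assumes two already-processed examples):
#   ex = [train_dataset[0], train_dataset[1]]          # defined further below
#   b = DataCollatorForCLM(tokenizer)(ex)
#   b["input_ids"], b["attention_mask"]  -> shape (2, 512)
#   b["labels"]  -> copy of input_ids with padded positions set to -100
#   b["time"]    -> present only when the examples carry a "time_ratio" column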
from datetime import date
import datetime

data_collator = DataCollatorForCLM(tokenizer)

from functools import partial

start_date = datetime.date(args.start_year, 1, 1).toordinal()
end_date = datetime.date(args.end_year, 12, 31).toordinal()


def add_time_ratio(examples, start_date=start_date, end_date=end_date):
    # map each ordinal date onto [0, 1] between start_year-01-01 and end_year-12-31
    examples["time_ratio"] = [
        (time - start_date) / (end_date - start_date) for time in examples["time"]
    ]
    return examples
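# Worked example (illustrative): with start_year=2010 and end_year=2017, an
# abstract dated 2013-07-01 gets
#   (date(2013, 7, 1).toordinal() - start_date) / (end_date - start_date) ≈ 0.437
# i.e. roughly 44% of the way through the 2010-2017 window.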
def process_dataset(ds):
    ds = ds.remove_columns(
        [x for x in ds.column_names if x not in ["abstract", "time"]]
    )
    ds = ds.map(add_time_ratio, batched=True)
    return ds


def process_dataset_no_time(ds):
    ds = ds.remove_columns([x for x in ds.column_names if x not in ["abstract"]])
    return ds
# time prefix
def format_time(ordinal):
    return date.fromordinal(ordinal).strftime("%B %Y: ")


def add_time_prefix(examples, start_date=start_date, end_date=end_date):
    examples["abstract"] = [
        format_time(t) + a for t, a in zip(examples["time"], examples["abstract"])
    ]
    return examples


def process_dataset_add_time_prefix(ds, map_fn=add_time_prefix):
    ds = ds.remove_columns(
        [x for x in ds.column_names if x not in ["abstract", "time"]]
    )
    ds = ds.map(map_fn, batched=True)
    ds = ds.remove_columns("time")
    return ds


def format_time2(time, article):
    # prefix only articles with an odd date ordinal, i.e. roughly half of them
    if time % 2 == 0:
        return article
    prefix = date.fromordinal(time).strftime("%B %Y: ")
    return prefix + article


def add_time_prefix2(examples):
    examples["abstract"] = [
        format_time2(t, a) for t, a in zip(examples["time"], examples["abstract"])
    ]
    return examples


process_dataset_add_time_prefix2 = partial(
    process_dataset_add_time_prefix, map_fn=add_time_prefix2
)
def format_time_suffix(ordinal):
    return date.fromordinal(ordinal).strftime(" %B %Y")


def add_time_suffix(examples, start_date=start_date, end_date=end_date):
    examples["abstract"] = [
        a + format_time_suffix(t)
        for t, a in zip(examples["time"], examples["abstract"])
    ]
    return examples


def process_dataset_add_time_suffix(ds):
    ds = ds.remove_columns(
        [x for x in ds.column_names if x not in ["abstract", "time"]]
    )
    ds = ds.map(add_time_suffix, batched=True)
    ds = ds.remove_columns("time")
    return ds
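# Illustrative formatting of a single abstract (text shortened, not real data):
#   prefix / prefix2:  "July 2013: We study ..."
#   suffix:            "We study ... July 2013"
#   v2 / f:            abstract unchanged; the date is passed separately as time_ratio
#   none:              abstract unchanged, no time information at all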
if args.time_mode in ["v2", "f"]:
    preprocess_fn = process_dataset
elif args.time_mode in ["prefix"]:
    preprocess_fn = process_dataset_add_time_prefix
elif args.time_mode in ["prefix2"]:
    preprocess_fn = process_dataset_add_time_prefix2
elif args.time_mode in ["suffix"]:
    preprocess_fn = process_dataset_add_time_suffix
else:
    assert args.time_mode == "none"
    preprocess_fn = process_dataset_no_time

train_dataset, eval_dataset = map(
    preprocess_fn,
    (dss["train"], dss["dev"]),
)
# for i in range(10):
#     x = tokenizer(train_dataset[i]["abstract"])
#     print(tokenizer.convert_ids_to_tokens(x["input_ids"])[:10])
# for m in range(1, 13):
#     for y in range(2010, 2020):
#         x = datetime.date(y, m, 1)
#         x = format_time(x.toordinal())
#         x = tokenizer(x)
#         print(tokenizer.convert_ids_to_tokens(x["input_ids"]))
args.time_mode
print(f'{train_dataset[0]=}')
# # sanity check
# import seaborn as sns
# sns.displot(train_dataset["time_ratio"])
train_dataloader = DataLoader(
    train_dataset,
    shuffle=args.shuffle,
    collate_fn=data_collator,
    batch_size=args.per_device_train_batch_size,
    num_workers=1,
)
eval_dataloader = DataLoader(
    eval_dataset,
    collate_fn=data_collator,
    batch_size=args.per_device_eval_batch_size,
    num_workers=1,
)
# b = next(iter(train_dataloader))
# model(**b)
if args.pretrained:
    gpt2model = AutoModelForCausalLM.from_pretrained(args.model_name)
else:
    config = AutoConfig.from_pretrained(args.model_name)
    gpt2model = AutoModelForCausalLM.from_config(config)

# TimeV2GPT2LMHeadModel and FGPT2LMHeadModel are custom time-conditioned GPT-2
# variants, presumably provided by the `src` star-import above
if args.time_mode == "v2":
    model = TimeV2GPT2LMHeadModel.from_vanilla_model(gpt2model)
elif args.time_mode == "f":
    model = FGPT2LMHeadModel.from_pretrained(args.model_name)
else:
    model = gpt2model
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [
            p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(
    optimizer_grouped_parameters, lr=args.learning_rate  # * total_batch_size
)
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)
def loss_fn(outputs, batch, clamp=10):
    b = outputs.logits.size(0)
    lm_logits = outputs.logits
    labels = batch["labels"]
    # shift so that tokens < n predict token n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_mask = batch["attention_mask"][..., 1:].contiguous()
    # Flatten the tokens
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none",
    ).view(b, -1)
    # clamp per-token losses, then average over non-padded positions only
    loss = torch.clamp(loss, -clamp, clamp)
    return (loss * shift_mask).sum() / shift_mask.sum()
    # ppl = torch.exp((loss * shift_mask).sum(-1) / shift_mask.sum(-1))
    # return ppl
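# Note on the clamp: per-token cross-entropy is capped at `clamp` (10 nats here)
# before the masked average, so a handful of very unlikely tokens cannot blow up
# a batch's loss; the lower bound of -clamp is never active since cross-entropy
# is non-negative.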
# Note -> the training dataloader needs to be prepared before we grab its length
# below (because its length will be shorter in multiprocess)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(
    len(train_dataloader) / args.gradient_accumulation_steps
)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
    args.num_train_epochs = math.ceil(
        args.max_train_steps / num_update_steps_per_epoch
    )
eval_per_n_step = args.max_train_steps // args.num_eval
scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)
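# Illustrative step math (hypothetical numbers, not from the real dataset): with
# 100_000 training examples, per_device_train_batch_size=2 on a single process,
# and gradient_accumulation_steps=2:
#   len(train_dataloader)        = 50_000 batches
#   num_update_steps_per_epoch   = ceil(50_000 / 2) = 25_000
#   max_train_steps (1 epoch)    = 25_000
#   eval_per_n_step (num_eval=3) = 25_000 // 3 = 8_333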
progress_bar = tqdm(range(args.max_train_steps), desc="training")
completed_steps = 0
logger = Logger(project="time_lm", config=args)
for epoch in range(args.num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        # pop labels so the model forward does not compute its own loss;
        # the custom loss_fn below uses them instead
        labels = batch.pop("labels")
        outputs = model(**batch)
        batch["labels"] = labels
        loss = loss_fn(outputs=outputs, batch=batch)
        logger.log({"loss": loss})
        loss = loss / args.gradient_accumulation_steps
        accelerator.backward(loss)
        # only step the optimizer every gradient_accumulation_steps micro-batches
        # (and on the last batch of the epoch)
        do_step = (
            step % args.gradient_accumulation_steps == 0
            or step == len(train_dataloader) - 1
        )
        if do_step:
            # accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            completed_steps += 1
        else:
            continue
        if completed_steps > 0 and completed_steps % eval_per_n_step == 0:
            model.eval()
            losses = []
            for step, batch in enumerate(
                tqdm(eval_dataloader, desc="eval")  # , leave=False)
            ):
                with torch.no_grad():
                    outputs = model(**batch)
                loss = outputs.loss
                losses.append(
                    accelerator.gather(loss.repeat(args.per_device_eval_batch_size))
                )
            losses = torch.cat(losses)
            losses = losses[: len(eval_dataset)]
            perplexity = math.exp(torch.mean(losses))
            logger.log({"val_perplexity": perplexity})
            model.train()
            # logger.info(f"epoch {epoch}: perplexity: {perplexity}")
            if is_local_main_process:
                save_dict = {
                    "epoch": epoch + 1,
                    "state_dict": accelerator.unwrap_model(model).state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                }
                torch.save(
                    save_dict,
                    os.path.join(args.out_dir, f"checkpoint-{completed_steps}step.pt"),
                )
        if completed_steps >= args.max_train_steps:
            break
logger.close()
if is_local_main_process and args.out_dir is not None:
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(args.out_dir, save_function=accelerator.save)