# Gist: cccntu/9a705f6b26834c2fc7d3f4a3de5c8130 (created July 9, 2021)
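# Fine-tunes GPT-2 on arXiv abstracts while injecting each paper's date in one of
# several ways (a numeric time feature fed to custom model variants, or a plain-text
# date prefix/suffix), using Hugging Face transformers + accelerate, with
# Hydra / OmegaConf handling the configuration.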
import sys
# imports utils and imports
from src import *
from dataclasses import dataclass
from typing import Optional
from omegaconf import OmegaConf
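# Run configuration; the dataclass doubles as the Hydra structured-config schema
# registered with ConfigStore below.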
@dataclass
class CFG:
    per_device_eval_batch_size: int = 2
    per_device_train_batch_size: int = 2
    weight_decay: float = 0.0
    learning_rate: float = 5e-5
    gradient_accumulation_steps: int = 2
    num_train_epochs: int = 1
    max_train_steps: Optional[int] = None
    lr_scheduler_type: str = "linear"
    num_warmup_steps: int = 1000
    fast_run_pct: Optional[float] = None
    seed: int = 42
    out_dir: str = "output_dir"
    shuffle: bool = True
    h: bool = False
    num_eval: int = 3
    start_year: int = 2010
    end_year: int = 2017
    model_name: str = "gpt2"
    pretrained: bool = True
    time_mode: str = "v2"
cs = ConfigStore.instance()
cs.store(name="config", node=CFG)
args = CFG()
args.time_mode = "suffix"
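# When executed as a script (IS_SCRIPT), Hydra parses the CLI and overwrites the
# module-level `args` via main(); otherwise the defaults above (plus the manual
# time_mode override) are used, notebook-style.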
@hydra.main(config_name="config")
def main(cfg: CFG) -> None:
    global args
    args = cfg
if __name__ == "__main__":
    if IS_SCRIPT:
        main()
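# Example invocation with Hydra overrides (the script name is hypothetical):
#   python train.py time_mode=prefix num_train_epochs=2 learning_rate=3e-5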
print(OmegaConf.to_yaml(args))
if args.h:
    sys.exit()
set_seed(args.seed)
accelerator = Accelerator(fp16=True)
tqdm = partial(tqdm, disable=not accelerator.is_local_main_process)
# post-process args
is_local_main_process = accelerator.is_local_main_process
total_batch_size = (
    args.per_device_train_batch_size
    * accelerator.num_processes
    * args.gradient_accumulation_steps
)
os.makedirs(args.out_dir, exist_ok=True)
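# Thin wandb wrapper: only the local main process creates a run and logs, so
# multi-GPU training does not spawn duplicate wandb runs.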
class Logger:
    def __init__(self, *args, **kwargs):
        if is_local_main_process:
            self.run = wandb.init(*args, **kwargs)

    def log(self, dic):
        if is_local_main_process:
            wandb.log(dic)

    def close(self):
        if is_local_main_process:
            wandb.finish()
import datatable as dt
from datasets import load_from_disk, load_dataset
from datasets import Dataset
dataset = load_from_disk("../input/processed/arxiv_processed_v2.json")
dataset = dataset["train"]
df = dt.fread("../input/processed/arxiv_id_date_split.csv").to_pandas()
dss = {}
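# Select the train/dev/test subsets restricted to the configured year range,
# using the arxiv_id/date/split table to index into the processed dataset.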
for s in ["train", "dev", "test"]:
    ids = df.query(
        f'split == "{s}" and year >= {args.start_year} and year <= {args.end_year}'
    ).index
    # dataset['time'] = time
    # print(ids)
    dss[s] = dataset.select(ids)
if args.fast_run_pct is not None:
    pct = 1 - args.fast_run_pct
    dss = {
        s: ds.select(np.arange(int(len(ds) * pct), len(ds))) for s, ds in dss.items()
    }
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
tokenizer.pad_token = tokenizer.eos_token
from transformers import PreTrainedTokenizerBase
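# Collator: tokenizes abstracts to fixed-length (512) tensors, masks padding in the
# labels with -100 so it is ignored by the loss, and forwards the per-example
# time_ratio as a `time` tensor when that feature is present.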
@dataclass
class DataCollatorForCLM:
    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: int = 16

    def __call__(self, batch):
        time_col = "time_ratio"
        has_time = time_col in batch[0]
        if has_time:
            time = torch.tensor([x[time_col] for x in batch])
        batch = self.tokenizer(
            [x["abstract"] for x in batch],
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        labels = batch["input_ids"].clone()
        # force an error if there is no pad_token
        # if self.tokenizer.pad_token_id is not None:
        labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        if has_time:
            batch["time"] = time
        return batch
from datetime import date
import datetime
data_collator = DataCollatorForCLM(tokenizer)
from functools import partial
start_date = datetime.date(args.start_year, 1, 1).toordinal()
end_date = datetime.date(args.end_year, 12, 31).toordinal()
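# Preprocessing variants, selected by args.time_mode further below:
#   - "v2"/"f":  keep a numeric time_ratio feature for the time-conditioned models
#   - "prefix":  prepend "Month Year: " to every abstract
#   - "prefix2": prepend the date to roughly half of the abstracts
#   - "suffix":  append " Month Year" to every abstract
#   - "none":    drop time information entirely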
def add_time_ratio(examples, start_date=start_date, end_date=end_date):
examples["time_ratio"] = [
(time - start_date) / (end_date - start_date) for time in examples["time"]
]
return examples
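# e.g. with start_year=2010 and end_year=2017, a paper dated 2013-07-01 maps to
# time_ratio = (toordinal(2013-07-01) - start_date) / (end_date - start_date) ≈ 0.44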
def process_dataset(ds):
    ds = ds.remove_columns(
        [x for x in ds.column_names if x not in ["abstract", "time"]]
    )
    ds = ds.map(add_time_ratio, batched=True)
    return ds
def process_dataset_no_time(ds):
    ds = ds.remove_columns([x for x in ds.column_names if x not in ["abstract"]])
    return ds
# time prefix
def format_time(ordinal):
    return date.fromordinal(ordinal).strftime("%B %Y: ")
def add_time_prefix(examples, start_date=start_date, end_date=end_date):
examples["abstract"] = [
format_time(t) + a for t, a in zip(examples["time"], examples["abstract"])
]
return examples
def process_dataset_add_time_prefix(ds, map_fn=add_time_prefix):
    ds = ds.remove_columns(
        [x for x in ds.column_names if x not in ["abstract", "time"]]
    )
    ds = ds.map(map_fn, batched=True)
    ds = ds.remove_columns("time")
    return ds
def format_time2(time, article):
    # drop the prefix for (roughly) half the examples, keyed on the parity of the date ordinal
    if time % 2 == 0:
        return article
    prefix = date.fromordinal(time).strftime("%B %Y: ")
    return prefix + article
def add_time_prefix2(examples):
examples["abstract"] = [
format_time2(t, a) for t, a in zip(examples["time"], examples["abstract"])
]
return examples
process_dataset_add_time_prefix2 = partial(
    process_dataset_add_time_prefix, map_fn=add_time_prefix2
)
def format_time_suffix(ordinal):
    return date.fromordinal(ordinal).strftime(" %B %Y")
def add_time_suffix(examples, start_date=start_date, end_date=end_date):
examples["abstract"] = [
a + format_time_suffix(t) for t, a in zip(examples["time"], examples["abstract"])
]
return examples
def process_dataset_add_time_suffix(ds):
    ds = ds.remove_columns(
        [x for x in ds.column_names if x not in ["abstract", "time"]]
    )
    ds = ds.map(add_time_suffix, batched=True)
    ds = ds.remove_columns("time")
    return ds
if args.time_mode in ["v2", "f"]:
preprocess_fn = process_dataset
elif args.time_mode in ["prefix"]:
preprocess_fn = process_dataset_add_time_prefix
elif args.time_mode in ["prefix2"]:
preprocess_fn = process_dataset_add_time_prefix2
elif args.time_mode in ["suffix"]:
preprocess_fn = process_dataset_add_time_suffix
else:
assert args.time_mode == 'none'
preprocess_fn = process_dataset_no_time
train_dataset, eval_dataset = map(
    preprocess_fn,
    (dss["train"], dss["dev"]),
)
# for i in range(10):
# x = tokenizer(train_dataset[i]["abstract"])
# print(tokenizer.convert_ids_to_tokens(x["input_ids"])[:10])
# for m in range(1, 13):
# for y in range(2010, 2020):
# x = datetime.date(y, m, 1)
# x = format_time(x.toordinal())
# x = tokenizer(x)
# print(tokenizer.convert_ids_to_tokens(x["input_ids"]))
print(f'{train_dataset[0]=}')
# # sanity check
# import seaborn as sns
# sns.displot(train_dataset["time_ratio"])
train_dataloader = DataLoader(
    train_dataset,
    shuffle=args.shuffle,
    collate_fn=data_collator,
    batch_size=args.per_device_train_batch_size,
    num_workers=1,
)
eval_dataloader = DataLoader(
    eval_dataset,
    collate_fn=data_collator,
    batch_size=args.per_device_eval_batch_size,
    num_workers=1,
)
# b = next(iter(train_dataloader))
# model(**b)
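# Build the base GPT-2 model (pretrained or from scratch), then wrap it according to
# time_mode; TimeV2GPT2LMHeadModel and FGPT2LMHeadModel are custom classes, presumably
# pulled in via `from src import *`.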
if args.pretrained:
    gpt2model = AutoModelForCausalLM.from_pretrained(args.model_name)
else:
    config = AutoConfig.from_pretrained(args.model_name)
    gpt2model = AutoModelForCausalLM.from_config(config)
if args.time_mode == "v2":
    model = TimeV2GPT2LMHeadModel.from_vanilla_model(gpt2model)
elif args.time_mode == "f":
    model = FGPT2LMHeadModel.from_pretrained(args.model_name)
else:
    model = gpt2model
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [
            p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(
    optimizer_grouped_parameters, lr=args.learning_rate  # * total_batch_size
)
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)
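# Token-level cross entropy with the usual CLM shift-by-one; per-token losses are
# clamped to [-clamp, clamp] and averaged only over non-padding positions via the
# attention mask, instead of relying on the model's built-in loss.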
def loss_fn(outputs, batch, clamp=10):
    b = outputs.logits.size(0)
    lm_logits = outputs.logits
    labels = batch["labels"]
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_mask = batch["attention_mask"][..., 1:].contiguous()
    # Flatten the tokens
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none",
    ).view(b, -1)
    loss = torch.clamp(loss, -clamp, clamp)
    return (loss * shift_mask).sum() / shift_mask.sum()
    # ppl = torch.exp((loss * shift_mask).sum(-1) / shift_mask.sum(-1))
    # return ppl
# Note: the training dataloader needs to be prepared before we grab its length below
# (its length is shorter in multi-process training).
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(
    len(train_dataloader) / args.gradient_accumulation_steps
)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
    args.num_train_epochs = math.ceil(
        args.max_train_steps / num_update_steps_per_epoch
    )
eval_per_n_step = args.max_train_steps // args.num_eval
scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)
progress_bar = tqdm(range(args.max_train_steps), desc="training")
completed_steps = 0
logger = Logger(project="time_lm", config=args)
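# Training loop: accumulate gradients over gradient_accumulation_steps micro-batches,
# then step the optimizer/scheduler; every eval_per_n_step optimizer steps, compute
# validation perplexity and (on the main process) save a checkpoint.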
for epoch in range(args.num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        labels = batch.pop("labels")
        outputs = model(**batch)
        batch["labels"] = labels
        loss = loss_fn(outputs=outputs, batch=batch)
        logger.log({"loss": loss})
        loss = loss / args.gradient_accumulation_steps
        accelerator.backward(loss)
        do_step = (
            step % args.gradient_accumulation_steps == 0
            or step == len(train_dataloader) - 1
        )
        if do_step:
            # accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            completed_steps += 1
        else:
            continue
        if completed_steps > 0 and completed_steps % eval_per_n_step == 0:
            model.eval()
            losses = []
            for eval_step, batch in enumerate(
                tqdm(eval_dataloader, desc="eval")  # , leave=False)
            ):
                with torch.no_grad():
                    outputs = model(**batch)
                loss = outputs.loss
                losses.append(
                    accelerator.gather(loss.repeat(args.per_device_eval_batch_size))
                )
            losses = torch.cat(losses)
            losses = losses[: len(eval_dataset)]
            perplexity = math.exp(torch.mean(losses))
            logger.log({"val_perplexity": perplexity})
            model.train()
            # logger.info(f"epoch {epoch}: perplexity: {perplexity}")
            if is_local_main_process:
                save_dict = {
                    "epoch": epoch + 1,
                    "state_dict": accelerator.unwrap_model(model).state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                }
                torch.save(
                    save_dict,
                    os.path.join(args.out_dir, f"checkpoint-{completed_steps}step.pt"),
                )
        if completed_steps >= args.max_train_steps:
            break
logger.close()
if is_local_main_process and args.out_dir is not None:
    # note: the standard accelerate pattern calls wait_for_everyone() from every
    # process (not just the local main one) before saving
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(args.out_dir, save_function=accelerator.save)