Last active
January 11, 2023 03:24
-
-
Save buttercutter/34597783d681ce6407ff26ec3b76e56e to your computer and use it in GitHub Desktop.
flax training script with gradient accumulation trick
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding=utf-8 | |
# Copyright 2021 The HuggingFace Team All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
Fine-tuning the library models for summarization. | |
""" | |
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. | |
import json | |
import logging | |
import math | |
import os | |
import sys | |
import time | |
from dataclasses import asdict, dataclass, field | |
from enum import Enum | |
from functools import partial | |
from pathlib import Path | |
from typing import Callable, Optional | |
import datasets | |
import nltk # Here to have a nice missing dependency error message early on | |
import numpy as np | |
from datasets import Dataset, load_dataset | |
from tqdm import tqdm | |
import evaluate | |
import jax | |
import jax.numpy as jnp | |
import optax | |
import transformers | |
from filelock import FileLock | |
from flax import jax_utils, traverse_util | |
from flax.jax_utils import pad_shard_unpad, unreplicate | |
from flax.training import train_state | |
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key | |
from huggingface_hub import Repository | |
from transformers import ( | |
CONFIG_MAPPING, | |
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, | |
AutoConfig, | |
AutoTokenizer, | |
FlaxAutoModelForSeq2SeqLM, | |
HfArgumentParser, | |
is_tensorboard_available, | |
) | |
from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry | |
import pandas as pd | |
logger = logging.getLogger(__name__) | |
try: | |
nltk.data.find("tokenizers/punkt") | |
except (LookupError, OSError): | |
if is_offline_mode(): | |
raise LookupError( | |
"Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" | |
) | |
with FileLock(".lock") as lock: | |
nltk.download("punkt", quiet=True) | |
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()) | |
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) | |
@dataclass | |
class TrainingArguments: | |
output_dir: str = field( | |
metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, | |
) | |
overwrite_output_dir: bool = field( | |
default=False, | |
metadata={ | |
"help": ( | |
"Overwrite the content of the output directory. " | |
"Use this to continue training if output_dir points to a checkpoint directory." | |
) | |
}, | |
) | |
do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) | |
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) | |
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) | |
per_device_train_batch_size: int = field( | |
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} | |
) | |
per_device_eval_batch_size: int = field( | |
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} | |
) | |
# buttercuter added 12/6/22 | |
per_device_gradient_accumulation_steps: int = field( | |
default=1, | |
metadata={ | |
"help": "Number of updates steps to accumulate before performing a backward/update pass." | |
}, | |
) | |
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."}) | |
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) | |
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) | |
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) | |
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) | |
label_smoothing_factor: float = field( | |
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} | |
) | |
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) | |
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) | |
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) | |
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) | |
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) | |
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) | |
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) | |
push_to_hub: bool = field( | |
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."} | |
) | |
hub_model_id: str = field( | |
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} | |
) | |
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) | |
gradient_checkpointing: bool = field( | |
default=False, | |
metadata={ | |
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." | |
}, | |
) | |
def __post_init__(self): | |
if self.output_dir is not None: | |
self.output_dir = os.path.expanduser(self.output_dir) | |
def to_dict(self): | |
""" | |
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates | |
the token values by removing their value. | |
""" | |
d = asdict(self) | |
for k, v in d.items(): | |
if isinstance(v, Enum): | |
d[k] = v.value | |
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): | |
d[k] = [x.value for x in v] | |
if k.endswith("_token"): | |
d[k] = f"<{k.upper()}>" | |
return d | |
@dataclass | |
class ModelArguments: | |
""" | |
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. | |
""" | |
model_name_or_path: Optional[str] = field( | |
default=None, | |
metadata={ | |
"help": ( | |
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." | |
) | |
}, | |
) | |
model_type: Optional[str] = field( | |
default=None, | |
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, | |
) | |
config_name: Optional[str] = field( | |
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} | |
) | |
tokenizer_name: Optional[str] = field( | |
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} | |
) | |
cache_dir: Optional[str] = field( | |
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} | |
) | |
use_fast_tokenizer: bool = field( | |
default=True, | |
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, | |
) | |
dtype: Optional[str] = field( | |
default="float32", | |
metadata={ | |
"help": ( | |
"Floating-point format in which the model weights should be initialized and trained. Choose one of" | |
" `[float32, float16, bfloat16]`." | |
) | |
}, | |
) | |
use_auth_token: bool = field( | |
default=False, | |
metadata={ | |
"help": ( | |
"Will use the token generated when running `huggingface-cli login` (necessary to use this script " | |
"with private models)." | |
) | |
}, | |
) | |
freeze_feature_encoder: bool = field( | |
default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."} | |
) | |
@dataclass | |
class DataTrainingArguments: | |
""" | |
Arguments pertaining to what data we are going to input our model for training and eval. | |
""" | |
dataset_name: Optional[str] = field( | |
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} | |
) | |
dataset_config_name: Optional[str] = field( | |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} | |
) | |
text_column: Optional[str] = field( | |
default=None, | |
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, | |
) | |
summary_column: Optional[str] = field( | |
default=None, | |
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, | |
) | |
train_file: Optional[str] = field(default="train_small_20.csv", metadata={"help": "The input training data file (a text file)."}) | |
validation_file: Optional[str] = field( | |
default="eval_small_10.csv", | |
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, | |
) | |
test_file: Optional[str] = field( | |
default="eval_small_10.csv", | |
metadata={"help": "An optional input predict data file to do prediction on (a text file)."}, | |
) | |
max_source_length: Optional[int] = field( | |
default=1024, | |
metadata={ | |
"help": ( | |
"The maximum total input sequence length after tokenization. Sequences longer " | |
"than this will be truncated, sequences shorter will be padded." | |
) | |
}, | |
) | |
max_target_length: Optional[int] = field( | |
default=128, | |
metadata={ | |
"help": ( | |
"The maximum total sequence length for target text after tokenization. Sequences longer " | |
"than this will be truncated, sequences shorter will be padded." | |
) | |
}, | |
) | |
val_max_target_length: Optional[int] = field( | |
default=None, | |
metadata={ | |
"help": ( | |
"The maximum total sequence length for validation target text after tokenization. Sequences longer " | |
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." | |
"This argument is also used to override the `max_length` param of `model.generate`, which is used " | |
"during evaluation." | |
) | |
}, | |
) | |
max_train_samples: Optional[int] = field( | |
default=None, | |
metadata={ | |
"help": ( | |
"For debugging purposes or quicker training, truncate the number of training examples to this " | |
"value if set." | |
) | |
}, | |
) | |
max_eval_samples: Optional[int] = field( | |
default=None, | |
metadata={ | |
"help": ( | |
"For debugging purposes or quicker training, truncate the number of evaluation examples to this " | |
"value if set." | |
) | |
}, | |
) | |
max_predict_samples: Optional[int] = field( | |
default=None, | |
metadata={ | |
"help": ( | |
"For debugging purposes or quicker training, truncate the number of prediction examples to this " | |
"value if set." | |
) | |
}, | |
) | |
preprocessing_num_workers: Optional[int] = field( | |
default=None, | |
metadata={"help": "The number of processes to use for the preprocessing."}, | |
) | |
source_prefix: Optional[str] = field( | |
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} | |
) | |
predict_with_generate: bool = field( | |
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} | |
) | |
num_beams: Optional[int] = field( | |
default=None, | |
metadata={ | |
"help": ( | |
"Number of beams to use for evaluation. This argument will be passed to `model.generate`, " | |
"which is used during evaluation." | |
) | |
}, | |
) | |
overwrite_cache: bool = field( | |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} | |
) | |
def __post_init__(self): | |
if self.dataset_name is None and self.train_file is None and self.validation_file is None: | |
raise ValueError("Need either a dataset name or a training/validation file.") | |
else: | |
if self.train_file is not None: | |
extension = self.train_file.split(".")[-1] | |
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." | |
if self.validation_file is not None: | |
extension = self.validation_file.split(".")[-1] | |
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." | |
if self.val_max_target_length is None: | |
self.val_max_target_length = self.max_target_length | |
summarization_name_mapping = { | |
"amazon_reviews_multi": ("review_body", "review_title"), | |
"big_patent": ("description", "abstract"), | |
"cnn_dailymail": ("article", "highlights"), | |
"orange_sum": ("text", "summary"), | |
"pn_summary": ("article", "summary"), | |
"psc": ("extract_text", "summary_text"), | |
"samsum": ("dialogue", "summary"), | |
"thaisum": ("body", "summary"), | |
"xglue": ("news_body", "news_title"), | |
"xsum": ("document", "summary"), | |
"wiki_summary": ("article", "highlights"), | |
} | |
class TrainState(train_state.TrainState): | |
dropout_rng: jnp.ndarray | |
def replicate(self): | |
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) | |
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): | |
""" | |
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, | |
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. | |
""" | |
if shuffle: | |
batch_idx = jax.random.permutation(rng, len(dataset)) | |
batch_idx = np.asarray(batch_idx) | |
else: | |
batch_idx = np.arange(len(dataset)) | |
if drop_last: | |
steps_per_epoch = len(dataset) // batch_size | |
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. | |
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) | |
else: | |
steps_per_epoch = math.ceil(len(dataset) / batch_size) | |
batch_idx = np.array_split(batch_idx, steps_per_epoch) | |
for idx in batch_idx: | |
batch = dataset[idx] | |
batch = {k: np.array(v) for k, v in batch.items()} | |
yield batch | |
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): | |
summary_writer.scalar("train_time", train_time, step) | |
train_metrics = get_metrics(train_metrics) | |
for key, vals in train_metrics.items(): | |
tag = f"train_{key}" | |
for i, val in enumerate(vals): | |
summary_writer.scalar(tag, val, step - len(vals) + i + 1) | |
for metric_name, value in eval_metrics.items(): | |
summary_writer.scalar(f"eval_{metric_name}", value, step) | |
def create_learning_rate_fn( | |
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float | |
) -> Callable[[int], jnp.array]: | |
"""Returns a linear warmup, linear_decay learning rate function.""" | |
steps_per_epoch = train_ds_size // train_batch_size | |
num_train_steps = steps_per_epoch * num_train_epochs | |
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) | |
decay_fn = optax.linear_schedule( | |
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps | |
) | |
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) | |
return schedule_fn | |
def to_fp32(t): | |
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) | |
def to_bf16(t): | |
return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t) | |
def main(): | |
# See all possible arguments in src/transformers/training_args.py | |
# or by passing the --help flag to this script. | |
# We now keep distinct sets of args, for a cleaner separation of concerns. | |
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) | |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): | |
# If we pass only one argument to the script and it's the path to a json file, | |
# let's parse it to get our arguments. | |
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) | |
else: | |
model_args, data_args, training_args = parser.parse_args_into_dataclasses() | |
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The | |
# information sent is the one passed as arguments along with your Python/PyTorch versions. | |
send_example_telemetry("run_summarization", model_args, data_args, framework="flax") | |
if ( | |
os.path.exists(training_args.output_dir) | |
and os.listdir(training_args.output_dir) | |
and training_args.do_train | |
and not training_args.overwrite_output_dir | |
): | |
raise ValueError( | |
f"Output directory ({training_args.output_dir}) already exists and is not empty." | |
"Use --overwrite_output_dir to overcome." | |
) | |
# Make one log on every process with the configuration for debugging. | |
logging.basicConfig( | |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", | |
datefmt="%m/%d/%Y %H:%M:%S", | |
level=logging.INFO, | |
) | |
# Setup logging, we only want one process per machine to log things on the screen. | |
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) | |
if jax.process_index() == 0: | |
datasets.utils.logging.set_verbosity_warning() | |
transformers.utils.logging.set_verbosity_info() | |
else: | |
datasets.utils.logging.set_verbosity_error() | |
transformers.utils.logging.set_verbosity_error() | |
# Set the verbosity to info of the Transformers logger (on main process only): | |
logger.info(f"Training/evaluation parameters {training_args}") | |
to_dtype = to_bf16 if model_args.dtype=="bfloat16" else to_fp32 | |
# Handle the repository creation | |
if training_args.push_to_hub: | |
if training_args.hub_model_id is None: | |
repo_name = get_full_repo_name( | |
Path(training_args.output_dir).absolute().name, token=training_args.hub_token | |
) | |
else: | |
repo_name = training_args.hub_model_id | |
repo = Repository(training_args.output_dir, clone_from=repo_name) | |
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) | |
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ | |
# (the dataset will be downloaded automatically from the datasets Hub). | |
# | |
# For CSV/JSON files this script will use the first column for the full texts and the second column for the | |
# summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). | |
# | |
if data_args.dataset_name is not None: | |
# Downloading and loading a dataset from the hub. | |
dataset = load_dataset( | |
data_args.dataset_name, | |
data_args.dataset_config_name, | |
cache_dir=model_args.cache_dir, | |
keep_in_memory=False, | |
use_auth_token=True if model_args.use_auth_token else None, | |
) | |
else: | |
os.system("wget -O train_small_20.csv https://pastebin.com/raw/kRdkD7bm") | |
os.system("wget -O eval_small_10.csv https://pastebin.com/raw/bzuCNPHZ") | |
os.system("wget -O eval_arc_test_dataset.csv https://pastebin.com/raw/u8DyP7dT") | |
# https://stackoverflow.com/a/30529461 | |
# needs to get the columns of 'prompt' and 'correct_answer' only | |
# for training and validation respectively | |
from petl import fromcsv, look, cut, tocsv | |
#Load the table | |
table1_train = fromcsv(data_args.train_file) # train_small_20.csv | |
# Alter the colums | |
table2_train = cut(table1_train, 'prompt', 'correct_answer') | |
#have a quick look to make sure things are ok. Prints a nicely formatted table to your console | |
print(look(table2_train)) | |
# Save to new file | |
tocsv(table2_train, 'train.csv') | |
#Load the table | |
table1_validate = fromcsv(data_args.validation_file) # eval_small_10.csv | |
# Alter the colums | |
table2_validate = cut(table1_validate, 'prompt', 'correct_answer') | |
#have a quick look to make sure things are ok. Prints a nicely formatted table to your console | |
print(look(table2_validate)) | |
# Save to new file | |
tocsv(table2_validate, 'validate.csv') | |
#Load the table | |
table1_test = fromcsv(data_args.test_file) # eval_arc_test_dataset.csv | |
# Alter the colums | |
table2_test = cut(table1_test, 'prompt', 'correct_answer') | |
#have a quick look to make sure things are ok. Prints a nicely formatted table to your console | |
print(look(table2_test)) | |
# Save to new file | |
tocsv(table2_test, 'test.csv') | |
data_files = {} | |
if data_args.train_file is not None: | |
data_files["train"] = 'train.csv' #data_args.train_file | |
extension = data_args.train_file.split(".")[-1] | |
if data_args.validation_file is not None: | |
data_files["validation"] = 'validate.csv' #data_args.validation_file | |
extension = data_args.validation_file.split(".")[-1] | |
if data_args.test_file is not None: | |
data_files["test"] = 'test.csv' #data_args.test_file | |
extension = data_args.test_file.split(".")[-1] | |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at | |
# https://huggingface.co/docs/datasets/loading_datasets.html | |
dataset = load_dataset( | |
extension, | |
data_files=data_files, | |
cache_dir=model_args.cache_dir, | |
keep_in_memory=True, # https://huggingface.co/docs/datasets/cache#enable-or-disable-caching | |
use_auth_token=True if model_args.use_auth_token else None, | |
) | |
# Load pretrained model and tokenizer | |
if model_args.config_name: | |
config = AutoConfig.from_pretrained( | |
model_args.config_name, | |
cache_dir=model_args.cache_dir, | |
use_auth_token=True if model_args.use_auth_token else None, | |
) | |
elif model_args.model_name_or_path: | |
config = AutoConfig.from_pretrained( | |
model_args.model_name_or_path, | |
cache_dir=model_args.cache_dir, | |
use_auth_token=True if model_args.use_auth_token else None, | |
) | |
else: | |
config = CONFIG_MAPPING[model_args.model_type]() | |
logger.warning("You are instantiating a new config instance from scratch.") | |
if model_args.tokenizer_name: | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_args.tokenizer_name, | |
cache_dir=model_args.cache_dir, | |
use_fast=model_args.use_fast_tokenizer, | |
use_auth_token=True if model_args.use_auth_token else None, | |
) | |
elif model_args.model_name_or_path: | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_args.model_name_or_path, | |
cache_dir=model_args.cache_dir, | |
use_fast=model_args.use_fast_tokenizer, | |
use_auth_token=True if model_args.use_auth_token else None, | |
) | |
else: | |
raise ValueError( | |
"You are instantiating a new tokenizer from scratch. This is not supported by this script." | |
"You can do it from another script, save it, and load it from here, using --tokenizer_name." | |
) | |
if model_args.model_name_or_path: | |
model = FlaxAutoModelForSeq2SeqLM.from_pretrained( | |
model_args.model_name_or_path, | |
config=config, | |
seed=training_args.seed, | |
dtype=getattr(jnp, model_args.dtype), | |
use_auth_token=True if model_args.use_auth_token else None, | |
) | |
else: | |
model = FlaxAutoModelForSeq2SeqLM.from_config( | |
config, | |
seed=training_args.seed, | |
dtype=getattr(jnp, model_args.dtype), | |
) | |
if training_args.gradient_checkpointing: | |
model.enable_gradient_checkpointing() | |
if model.config.decoder_start_token_id is None: | |
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") | |
prefix = data_args.source_prefix if data_args.source_prefix is not None else "" | |
# Preprocessing the datasets. | |
# We need to tokenize inputs and targets. | |
if training_args.do_train: | |
column_names = dataset["train"].column_names | |
elif training_args.do_eval: | |
column_names = dataset["validation"].column_names | |
elif training_args.do_predict: | |
column_names = dataset["test"].column_names | |
else: | |
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") | |
return | |
# Get the column names for input/target. | |
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) | |
if data_args.text_column is None: | |
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] | |
else: | |
text_column = data_args.text_column | |
if text_column not in column_names: | |
raise ValueError( | |
f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" | |
) | |
if data_args.summary_column is None: | |
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] | |
else: | |
summary_column = data_args.summary_column | |
if summary_column not in column_names: | |
raise ValueError( | |
f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" | |
) | |
# Temporarily set max_target_length and max_source_lengthfor training. | |
max_target_length = data_args.max_target_length | |
max_source_length = data_args.max_source_length | |
# In Flax, for seq2seq models we need to pass `decoder_input_ids` | |
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here | |
# for that dynamically import the `shift_tokens_right` function from the model file | |
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) | |
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") | |
# train_df_filtered = \ | |
# filter_file_for_max_tokens_and_add_tasks(model, tokenizer, file_name="train.csv", max_input_length=max_source_length, max_output_length=max_target_length, | |
# experiment_folder="./experiment/", write_to_file="train_filtered.csv", filter_by_token_count=True, shuffle=True, | |
# add_classification_task=False, add_correctness_task_ratio=0, add_fix_task_ratio=0) | |
# eval_df_filtered = \ | |
# filter_file_for_max_tokens_and_add_tasks(model, tokenizer, file_name="eval.csv", max_input_length=max_source_length, max_output_length=max_target_length, | |
# experiment_folder="./experiment/", write_to_file="eval_filtered.csv", filter_by_token_count=True, shuffle=True, | |
# add_classification_task=False, add_correctness_task_ratio=0, add_fix_task_ratio=0) | |
# test_df_filtered = \ | |
# filter_file_for_max_tokens_and_add_tasks(model, tokenizer, file_name="test.csv", max_input_length=max_source_length, max_output_length=max_target_length, | |
# experiment_folder="./experiment/", write_to_file="test_filtered.csv", filter_by_token_count=True, shuffle=True, | |
# add_classification_task=False, add_correctness_task_ratio=0, add_fix_task_ratio=0) | |
# Setting padding="max_length" as we need fixed length inputs for jitted functions | |
def preprocess_function(examples): | |
inputs = examples[text_column] | |
targets = examples[summary_column] | |
inputs = [prefix + inp for inp in inputs] | |
model_inputs = tokenizer( | |
inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np" | |
) | |
# Setup the tokenizer for targets | |
labels = tokenizer( | |
text=targets, | |
max_length=max_target_length, | |
padding="max_length", | |
truncation=True, | |
return_tensors="np", | |
) | |
model_inputs["labels"] = labels["input_ids"] | |
decoder_input_ids = shift_tokens_right_fn( | |
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id | |
) | |
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) | |
# We need decoder_attention_mask so we can ignore pad tokens from loss | |
model_inputs["decoder_attention_mask"] = labels["attention_mask"] | |
return model_inputs | |
if training_args.do_train: | |
if "train" not in dataset: | |
raise ValueError("--do_train requires a train dataset") | |
train_dataset = dataset["train"] | |
if data_args.max_train_samples is not None: | |
max_train_samples = min(len(train_dataset), data_args.max_train_samples) | |
train_dataset = train_dataset.select(range(max_train_samples)) | |
train_dataset = train_dataset.map( | |
preprocess_function, | |
batched=True, | |
num_proc=data_args.preprocessing_num_workers, | |
remove_columns=column_names, | |
load_from_cache_file=not data_args.overwrite_cache, | |
desc="Running tokenizer on train dataset", | |
) | |
if training_args.do_eval: | |
max_target_length = data_args.val_max_target_length | |
if "validation" not in dataset: | |
raise ValueError("--do_eval requires a validation dataset") | |
eval_dataset = dataset["validation"] | |
if data_args.max_eval_samples is not None: | |
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) | |
eval_dataset = eval_dataset.select(range(max_eval_samples)) | |
eval_dataset = eval_dataset.map( | |
preprocess_function, | |
batched=True, | |
num_proc=data_args.preprocessing_num_workers, | |
remove_columns=column_names, | |
load_from_cache_file=not data_args.overwrite_cache, | |
desc="Running tokenizer on validation dataset", | |
) | |
if training_args.do_predict: | |
max_target_length = data_args.val_max_target_length | |
if "test" not in dataset: | |
raise ValueError("--do_predict requires a test dataset") | |
predict_dataset = dataset["test"] | |
if data_args.max_predict_samples is not None: | |
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) | |
predict_dataset = predict_dataset.select(range(max_predict_samples)) | |
predict_dataset = predict_dataset.map( | |
preprocess_function, | |
batched=True, | |
num_proc=data_args.preprocessing_num_workers, | |
remove_columns=column_names, | |
load_from_cache_file=not data_args.overwrite_cache, | |
desc="Running tokenizer on prediction dataset", | |
) | |
def load_dataset_from_csv_file(file, prompt_column, answer_column): | |
max_target_length = data_args.val_max_target_length | |
this_dataset = load_dataset("csv", data_files=file) | |
return_dataset = this_dataset.map( | |
preprocess_function, | |
batched=True, | |
num_proc=data_args.preprocessing_num_workers, | |
remove_columns=column_names, | |
load_from_cache_file=not data_args.overwrite_cache, | |
desc="Running tokenizer on prediction dataset", | |
) | |
# Metric | |
metric = evaluate.load("rouge") | |
def postprocess_text(preds, labels): | |
preds = [pred.strip() for pred in preds] | |
labels = [label.strip() for label in labels] | |
# rougeLSum expects newline after each sentence | |
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] | |
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] | |
return preds, labels | |
def compute_metrics(preds, labels): | |
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) | |
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) | |
# Some simple post-processing | |
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) | |
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) | |
result = {k: round(v * 100, 4) for k, v in result.items()} | |
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] | |
result["gen_len"] = np.mean(prediction_lens) | |
return result | |
# Enable tensorboard only on the master node | |
has_tensorboard = is_tensorboard_available() | |
if has_tensorboard and jax.process_index() == 0: | |
try: | |
from flax.metrics.tensorboard import SummaryWriter | |
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) | |
except ImportError as ie: | |
has_tensorboard = False | |
logger.warning( | |
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" | |
) | |
else: | |
logger.warning( | |
"Unable to display metrics through TensorBoard because the package is not installed: " | |
"Please run pip install tensorboard to enable." | |
) | |
# Initialize our training | |
rng = jax.random.PRNGKey(training_args.seed) | |
rng, dropout_rng = jax.random.split(rng) | |
# Store some constant | |
num_epochs = int(training_args.num_train_epochs) | |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() | |
batch_size_per_update = train_batch_size * training_args.per_device_gradient_accumulation_steps | |
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) | |
eval_batch_size = per_device_eval_batch_size * jax.device_count() | |
try: | |
if train_batch_size > 0 and train_dataset!=None: | |
steps_per_epoch = len(train_dataset) // train_batch_size | |
else: | |
steps_per_epoch = 0 | |
except: | |
# does not exist | |
steps_per_epoch = 0 | |
total_train_steps = steps_per_epoch * num_epochs | |
# Create learning rate schedule | |
try: | |
linear_decay_lr_schedule_fn = create_learning_rate_fn( | |
len(train_dataset), | |
train_batch_size, | |
training_args.num_train_epochs, | |
training_args.warmup_steps, | |
training_args.learning_rate, | |
) | |
except: | |
print("") | |
# We use Optax's "masking" functionality to not apply weight decay | |
# to bias and LayerNorm scale parameters. decay_mask_fn returns a | |
# mask boolean with the same structure as the parameters. | |
# The mask is True for parameters that should be decayed. | |
def decay_mask_fn(params): | |
flat_params = traverse_util.flatten_dict(params) | |
# find out all LayerNorm parameters | |
layer_norm_candidates = ["layernorm", "layer_norm", "ln"] | |
layer_norm_named_params = set( | |
[ | |
layer[-2:] | |
for layer_norm_name in layer_norm_candidates | |
for layer in flat_params.keys() | |
if layer_norm_name in "".join(layer).lower() | |
] | |
) | |
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} | |
return traverse_util.unflatten_dict(flat_mask) | |
# create adam optimizer | |
try: | |
optimizer = optax.adamw( | |
learning_rate=linear_decay_lr_schedule_fn, | |
b1=training_args.adam_beta1, | |
b2=training_args.adam_beta2, | |
eps=training_args.adam_epsilon, | |
weight_decay=training_args.weight_decay, | |
mask=decay_mask_fn, | |
) | |
except: | |
print("") | |
# Setup train state | |
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng) | |
# label smoothed cross entropy | |
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): | |
""" | |
The label smoothing implementation is adapted from Flax's official example: | |
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 | |
""" | |
vocab_size = logits.shape[-1] | |
confidence = 1.0 - label_smoothing_factor | |
low_confidence = (1.0 - confidence) / (vocab_size - 1) | |
normalizing_constant = -( | |
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) | |
) | |
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) | |
loss = optax.softmax_cross_entropy(logits, soft_labels) | |
loss = loss - normalizing_constant | |
# ignore padded tokens from loss | |
loss = loss * padding_mask | |
loss = loss.sum() | |
num_labels = padding_mask.sum() | |
return loss, num_labels | |
# Define eval fn | |
def eval_step(params, batch, label_smoothing_factor=0.0): | |
labels = batch.pop("labels") | |
logits = model(**batch, params=params, train=False)[0] | |
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) | |
num_labels = jax.lax.psum(num_labels, "batch") | |
# true loss = total loss / total samples | |
loss = jax.lax.psum(loss, "batch") | |
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) | |
metrics = {"loss": loss} | |
return metrics | |
# Define generation function | |
max_length = ( | |
data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length | |
) | |
num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams | |
gen_kwargs = {"max_length": max_length, "num_beams": num_beams} | |
def generate_step(params, batch): | |
model.params = params | |
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) | |
return output_ids.sequences | |
def save_checkpoint(epoch, current_step): | |
''' | |
added by parapraxis on 12/5/22 | |
''' | |
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) | |
model.save_pretrained(training_args.output_dir, params=params) | |
tokenizer.save_pretrained(training_args.output_dir) | |
if training_args.push_to_hub: | |
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) | |
def get_predictions_for_dataset(pred_dataset, prompt_field, correct_answer_field=None, save_csv_file=True): | |
''' | |
added by parapraxis on 12/5/22 | |
''' | |
logger.info("*** Predict ***") | |
pred_metrics = [] | |
pred_generations = [] | |
pred_labels = [] | |
# load the dataset | |
pred_loader = data_loader(input_rng, pred_dataset, eval_batch_size, drop_last=False) | |
pred_steps = math.ceil(len(pred_dataset) / eval_batch_size) | |
for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False): | |
# Model forward | |
batch = next(pred_loader) | |
labels = batch["labels"] | |
metrics = pad_shard_unpad(p_eval_step, static_return=True)( | |
state.params, batch, min_device_batch=per_device_eval_batch_size | |
) | |
pred_metrics.append(metrics) | |
# generation | |
if data_args.predict_with_generate: | |
generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) | |
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) | |
pred_labels.extend(labels) | |
# normalize prediction metrics | |
pred_metrics = get_metrics(pred_metrics) | |
pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics) | |
# compute ROUGE metrics | |
rouge_desc = "" | |
if data_args.predict_with_generate: | |
rouge_metrics = compute_metrics(pred_generations, pred_labels) | |
pred_metrics.update(rouge_metrics) | |
rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()]) | |
# Print metrics | |
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})" | |
logger.info(desc) | |
# save final metrics in json | |
if jax.process_index() == 0: | |
rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()} | |
path = os.path.join(training_args.output_dir, "test_results.json") | |
with open(path, "w") as f: | |
json.dump(rouge_metrics, f, indent=4, sort_keys=True) | |
# added by parapraxis on 12/5/22 | |
# save predictions in a new file that is the eval file with the predictions added | |
if data_args.predict_with_generate: | |
if jax.process_index() == 0: | |
# add _results to the file name | |
fn_without_ext = os.path.splitext(data_args.test_file)[0] | |
if save_csv_file: | |
output_predict_file = fn_without_ext + f"_results_epoch_{epoch}.csv" | |
output_predict_file = os.path.join(training_args.output_dir, output_predict_file) | |
test_df = pd.read_csv(data_args.test_file) | |
# add decoded predictions using the tokenizer to test_df | |
test_df["predictions"] = tokenizer.batch_decode(pred_generations, skip_special_tokens=True) | |
if correct_answer_field is not None: | |
# compare predictions to "correct_answer" column after stripping spaces from the ends | |
test_df[correct_answer_field] = test_df[correct_answer_field].str.strip() | |
test_df["predictions"] = test_df["predictions"].str.strip() | |
test_df["correct"] = test_df["predictions"] == test_df[correct_answer_field] | |
# sum the correct items | |
num_correct = test_df["correct"].sum() | |
# get the total number of items | |
num_total = len(test_df) | |
# calculate the accuracy | |
accuracy = num_correct / num_total | |
if save_csv_file: | |
# add number of correct items to the file name | |
output_predict_file = fn_without_ext + f"_results_epoch_{epoch}_correct_{num_correct}.csv" | |
print(f"Saving predictions to {output_predict_file}") | |
print(f"Accuracy: {accuracy}") | |
print(f"Number correct: {num_correct}") | |
print(f"Total number: {num_total}") | |
if save_csv_file: | |
test_df.to_csv(output_predict_file, index=False) | |
else: | |
return test_df | |
def df_to_dataset(df): | |
print("loading dataset from pandas dataframe") | |
dataset = Dataset.from_pandas(df) | |
dataset = dataset.map(preprocess_function, batched=True) | |
return dataset | |
def get_predictions_for_df(data_frame, prompt_column_title, correct_answer_column_title, save_csv_file=True): | |
dataset_file = df_to_dataset(data_frame) | |
dataset_to_evaluate = load_dataset_from_csv_file(dataset_file) | |
get_predictions_for_dataset(dataset_to_evaluate, prompt_column_title, correct_answer_column_title, save_csv_file=save_csv_file) | |
def get_predictions_for_file(dataset_file, prompt_column_title, correct_answer_column_title, save_csv_file=True): | |
dataset_to_evaluate = load_dataset_from_csv_file(dataset_file) | |
get_predictions_for_dataset(dataset_to_evaluate, prompt_column_title, correct_answer_column_title, save_csv_file=save_csv_file) | |
def filter_file_for_max_tokens_and_add_tasks(model, tokenizer, file_name, max_input_length, max_output_length, experiment_folder, write_to_file=None, | |
filter_by_token_count=True, shuffle=True, add_classification_task=False, add_correctness_task_ratio=0, add_fix_task_ratio=0): | |
""" | |
This function takes a file_name and filters it for max_tokens and adds additional tasks if requested | |
Explanation of parameters: | |
file_name: the name of the dataset file to be filtered | |
write_to_file: if not empty, the filtered dataset will be written to this file | |
shuffle: if True, the dataset will be shuffled | |
add_classification_task: if True, the dataset will be doubled and the second half will be a classification task | |
add_correctness_task_ratio: if > 0, the dataset will be sampled by this ratio and the sampled data will be used to create a correctness task (i.e., is the answer correct?) | |
add_fix_task_ratio: if > 0, the dataset will be sampled by this ratio and the sampled data will be used to create a fix task (i.e., fix the answer) | |
""" | |
if "http" in file_name: | |
tdf = file_name | |
else: | |
if "/" not in file_name: | |
tdf = f"{experiment_folder}{file_name}" | |
else: | |
tdf = file_name | |
print(f"filter_file_for_max_tokens {tdf}") | |
df = pd.read_csv(tdf) | |
if "additional_tasks" not in file_name: # don't run if already added tasks | |
if add_correctness_task_ratio > 0: | |
print("adding correctness task") | |
# sample the add_correctness_task_ratio of the data | |
df_correctness_sample = df.sample(frac=add_correctness_task_ratio, random_state=42) | |
# run inference on the sampled data | |
df_correctness_sample['predictions'] = get_predictions_for_df(model, df_correctness_sample, save_csv_file=False) | |
print(f"inference complete for correctness task adding {len(df_correctness_sample)} items") | |
# add the correctness task 'prompt' plus 'predictions' and change correct_answer to yes if prediction==correct_answer and no if prediction!=correct_answer | |
# make a copy of 'prompt' to prompt_old | |
df_correctness_sample['prompt_old'] = df_correctness_sample['prompt'] | |
df_correctness_sample['prompt'] = "correct?:" + df_correctness_sample['prompt'] + " " + df_correctness_sample['predictions'] | |
# copy 'correct_answer' to 'correct_answer_board' | |
df_correctness_sample['correct_answer_board'] = df_correctness_sample['correct_answer'] | |
df_correctness_sample['correct_answer'] = df_correctness_sample.apply(lambda row: 'yes' if row['predictions'] == row['correct_answer'] else 'no', axis=1) | |
# print head | |
print("examples of correctness task") | |
print(df_correctness_sample.head()) | |
if add_fix_task_ratio > 0: | |
print("adding fix task") | |
if add_correctness_task_ratio > 0: | |
# then use the same sample for the fix task | |
df_fix_sample = df_correctness_sample | |
# use prompt_old for the fix task and correct_answer_board | |
# copy prompt_old to prompt and predictions to correct_answer | |
df_fix_sample['prompt'] = 'fix:' + df_fix_sample['prompt_old'] | |
df_fix_sample['correct_answer'] = df_fix_sample['predictions'] | |
else: | |
# sample the add_fix_task_ratio of the data | |
df_fix_sample = df.sample(frac=add_fix_task_ratio, random_state=42) | |
# run inference on the sampled data | |
print(f"running inference for fix task on {len(df_fix_sample)} items") | |
df_fix_sample['predictions'] = get_predictions_for_df(model, df_fix_sample) | |
df_fix_sample['prompt'] = 'fix:' + df_fix_sample['prompt'] | |
df_fix_sample['correct_answer'] = df_fix_sample['predictions'] | |
print(f"inference complete for fix task adding {len(df_fix_sample)} items") | |
print(df_fix_sample.head()) | |
if add_classification_task: | |
print("adding classification task - doubles the number of training items") | |
# create new column id_int that is 'id' without the .json extension and converted to int from hex | |
df['id_root_file'] = df['path_data'].apply(lambda x: str(x.split(".")[0].split("_")[0])) | |
df['id_int'] = df.groupby('id_root_file').ngroup() | |
# convert to string | |
df['id_int'] = df['id_int'].apply(lambda x: str(x)) | |
# iterate over the rows and add the classification task; add "solve:" to the prompt and duplicate the item adding classify to the duplicate | |
# duplicate each row and add "classify:" to the prompt (use concat) | |
df = pd.concat([df, df]) | |
# reset the index | |
df = df.reset_index(drop=True) | |
# add "classify:" to the prompt (the newly added items) in the second half of the dataframe | |
df.loc[df.index >= len(df)/2, 'prompt'] = "classify: " + df.loc[df.index >= len(df)/2, 'prompt'].astype(str) | |
# add "solve:" to the prompt (the original items) in the first half of the dataframe | |
df.loc[df.index < len(df)/2, 'prompt'] = "solve: " + df.loc[df.index < len(df)/2, 'prompt'].astype(str) | |
# for correct_answer in the second half, set it to df['path_data'] + the next to last string when the id string is split on "/" | |
df.loc[df.index >= len(df)/2, 'correct_answer'] = df.loc[df.index >= len(df)/2, 'id'].str.split("/").str[-2] + "/" + df.loc[df.index >= len(df)/2, 'path_data'] + ' ' + df.loc[df.index >= len(df)/2, 'id_int'] # put the id_int at the end of the correct_answer, because this could change if the dataset changes | |
# need to add these after the classification task, because the classify task adds numbers related to the line number | |
if add_correctness_task_ratio > 0: | |
df.extend(df_correctness_sample) | |
if add_fix_task_ratio > 0: | |
df.extend(df_fix_sample) | |
if add_classification_task or add_correctness_task_ratio > 0 or add_fix_task_ratio > 0: | |
# write to csv file | |
if write_to_file == None: | |
df.to_csv(f"{experiment_folder}additional_tasks_{file_name}", index=False) | |
else: | |
df.to_csv(f"{experiment_folder}additional_tasks_{write_to_file}", index=False) | |
if filter_by_token_count: | |
print(f"filtering dataset by token count") | |
# add the token count columns | |
#tokenized_prompts = tokenizer(df['prompt'].tolist(), padding=True, truncation=True, max_length=max_input_length, return_tensors="pt") | |
#df['token_count'] = df['prompt'].apply(lambda x: len(tokenized_prompts['input_ids'])) | |
df['token_count'] = df['prompt'].apply(lambda x: len(tokenizer.encode(x))) | |
# same for correct answer | |
df['token_count_correct'] = df['correct_answer'].apply(lambda x: len(tokenizer.encode(x))) | |
# filter out the examples that are too long | |
df = df[df['token_count'] <= max_input_length] | |
df = df[df['token_count_correct'] <= max_output_length] | |
# fix the index | |
df = df.reset_index(drop=True) | |
# randomize the dataset order | |
if shuffle: | |
# shuffle the dataset and do so in a reproducible way | |
print("shuffling dataset") | |
df = df.sample(frac=1, random_state=42).reset_index(drop=True) | |
#df = df.sample(frac=1).reset_index(drop=True) | |
return df | |
# Define gradient update step fn | |
def train_step(state, batch, label_smoothing_factor=0.0): | |
# only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch | |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) | |
def compute_loss(params, minibatch): | |
labels = minibatch.pop("labels") | |
logits = state.apply_fn( | |
**minibatch, | |
params=params, | |
dropout_rng=dropout_rng, | |
# freeze_feature_encoder=model_args.freeze_feature_encoder, | |
train=True, | |
)[0] | |
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) | |
return loss, num_labels | |
grad_fn = jax.value_and_grad(compute_loss, has_aux=True) | |
if training_args.per_device_gradient_accumulation_steps == 1: | |
(loss, num_labels), grad = grad_fn(to_dtype(state.params), batch) | |
# Custom gradient accumulation | |
else: | |
# print("batch = ", batch) | |
# See https://github.com/huggingface/transformers/issues/20855 | |
# add a first dimension over gradient_accumulation_steps for minibatch slices | |
batch = jax.tree_map( | |
lambda x: x.reshape( | |
training_args.per_device_train_batch_size, training_args.per_device_gradient_accumulation_steps, -1 #*x.shape[1::] | |
), | |
batch, | |
) | |
def accum_minibatch_step(accum_grad, minibatch): | |
# compute loss, num labels and grad over minibatch and accumulate | |
(loss, num_labels), grad = grad_fn(to_dtype(state.params), minibatch) | |
return jax.tree_map(jnp.add, accum_grad, grad), (loss, num_labels) | |
# create an initial state for accumulating losses, num labels and gradients | |
init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params)) | |
# loop accum minibatch step over the number of gradient accumulation steps | |
grad, (loss, num_labels) = jax.lax.scan(accum_minibatch_step, init_grad, batch) | |
grad = jax.lax.psum(grad, "batch") | |
loss = jax.lax.psum(loss.sum(), "batch") | |
total_samples = jax.lax.psum(num_labels.sum(), "batch") | |
grad = jax.tree_map(lambda g: g / total_samples, grad) | |
loss = jax.tree_map(lambda l: l / total_samples, loss) | |
# update state | |
new_state = state.apply_gradients( | |
grads=grad, | |
dropout_rng=new_dropout_rng, | |
# to_dtype=to_dtype, | |
) | |
# compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring | |
layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad) | |
''' | |
logs = { | |
"layer_grad_norm": layer_grad_norm, | |
"encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])), | |
"decoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["decoder"])), | |
} | |
logs["grad_norm"] = jnp.linalg.norm([logs["encoder_grad_norm"], logs["decoder_grad_norm"]]) | |
# compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring | |
layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params) | |
logs["layer_param_norm"] = layer_param_norm | |
logs["encoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["encoder"])) | |
logs["decoder_param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm["decoder"])) | |
logs["param_norm"] = jnp.linalg.norm([logs["encoder_param_norm"], logs["decoder_param_norm"]]) | |
''' | |
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | |
#metrics.update(logs) | |
metrics = jax.lax.pmean(metrics, axis_name="batch") | |
# metrics = to_fp32(metrics) | |
return new_state, metrics | |
# Create parallel version of the train and eval step | |
p_train_step = jax.pmap( | |
partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,) | |
) | |
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch") | |
p_generate_step = jax.pmap(generate_step, "batch") | |
# Replicate the train state on each device | |
state = state.replicate() | |
logger.info("***** Running training *****") | |
logger.info(f" Num examples = {len(train_dataset)}") | |
logger.info(f" Num Epochs = {num_epochs}") | |
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") | |
# buttercutter added 12/6/22 | |
logger.info(f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}") | |
# logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") | |
logger.info(f" Total optimization steps = {total_train_steps}") | |
train_time = 0 | |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) | |
for epoch in epochs: | |
# ======================== Training ================================ | |
train_start = time.time() | |
# Create sampling rng | |
rng, input_rng = jax.random.split(rng) | |
train_metrics = [] | |
# Generate an epoch by shuffling sampling indices from the train dataset | |
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) | |
steps_per_epoch = len(train_dataset) // train_batch_size | |
# train | |
for cstep in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): | |
batch = next(train_loader) | |
batch = shard(batch) | |
state, train_metric = p_train_step(state, batch) | |
train_metrics.append(train_metric) | |
# added by parapraxis on 12/5/22 | |
if cstep % training_args.logging_steps == 0: | |
# print data from train_metric | |
print("train_metric", train_metric) | |
if cstep % training_args.save_steps == 0: | |
save_checkpoint(epoch, state.step) | |
train_time += time.time() - train_start | |
# this was failing so I changed train_metric to train_metrics - parapraxis 11/30/22 | |
train_metric = unreplicate(train_metric) | |
# added by parapraxis on 12/5/22 | |
save_checkpoint(epoch, steps_per_epoch * (epoch + 1)) | |
epochs.write( | |
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:" | |
f" {train_metric['learning_rate']})" | |
) | |
# ======================== Evaluating ============================== | |
eval_metrics = [] | |
eval_preds = [] | |
eval_labels = [] | |
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) | |
eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) | |
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): | |
# Model forward | |
batch = next(eval_loader) | |
labels = batch["labels"] | |
metrics = pad_shard_unpad(p_eval_step, static_return=True)( | |
state.params, batch, min_device_batch=per_device_eval_batch_size | |
) | |
eval_metrics.append(metrics) | |
# generation | |
if data_args.predict_with_generate: | |
generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) | |
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) | |
eval_labels.extend(labels) | |
# normalize eval metrics | |
eval_metrics = get_metrics(eval_metrics) | |
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) | |
# compute ROUGE metrics | |
rouge_desc = "" | |
if data_args.predict_with_generate: | |
rouge_metrics = compute_metrics(eval_preds, eval_labels) | |
eval_metrics.update(rouge_metrics) | |
rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()]) | |
# Print metrics and update progress bar | |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})" | |
epochs.write(desc) | |
epochs.desc = desc | |
# Save metrics | |
if has_tensorboard and jax.process_index() == 0: | |
cur_step = epoch * (len(train_dataset) // train_batch_size) | |
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) | |
# prediction loop added to the epoch loop by parapraxis on 12/5/22 | |
# ======================== Prediction loop ============================== | |
if training_args.do_predict: | |
get_predictions_for_dataset(predict_dataset, text_column, summary_column) | |
if __name__ == "__main__": | |
main() | |
# # normal operation | |
# if __name__ == "__main__": | |
# main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment