from easydel import (
    TrainArguments,
    AutoEasyDeLModelForCausalLM,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    SFTTrainer,
    conversations_formatting_function,  # added for newcomers: if you don't know
    # what's going on, you can use this pre-built prompter
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer
from jax.sharding import PartitionSpec
import easydel as ed
import jax
sharding_axis_dims = (1, -1, 1, 1)
max_length = 2048
input_shape = (32, max_length)  # batch of 32, since you're using a TPU v4-64
huggingface_repo_id_or_path = "numfa/open_llama_3b_thai"
dtype = jnp.bfloat16
block_size = 512
attn_mechanism = "sharded_vanilla"  # alternatives include "flash", "vanilla", ...
partition_axis = ed.PartitionAxis()
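# Added sanity check, not in the original gist: assuming EasyDeL's default axis
# names (dp, fsdp, tp, sp), the -1 above lets the fsdp axis absorb every
# available device, so the resolved mesh size must equal the device count.
mesh_size = 1
for dim in sharding_axis_dims:
    mesh_size *= jax.device_count() if dim == -1 else dim
assert mesh_size == jax.device_count(), (mesh_size, jax.device_count())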
model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    device=jax.devices('cpu')[0],  # stage the weights on CPU first
    input_shape=input_shape,
    device_map="auto",
    auto_shard_params=False,
    sharding_axis_dims=sharding_axis_dims,
    verbose_params=True,
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism=attn_mechanism,
        partition_axis=partition_axis
    ),
    partition_axis=partition_axis
)
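# Optional (added, not in the original gist): report how many parameters were
# loaded. `params` is a plain pytree of jax arrays, so this needs nothing extra.
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"loaded {n_params / 1e9:.2f}B parameters")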
config = model.config
rules = (
    ('model/embed_tokens/embedding', PartitionSpec("tp", ('fsdp', 'sp'))),
    ('self_attn/(q_proj|k_proj|v_proj)/kernel', PartitionSpec(('fsdp', 'sp'), "tp")),
    ('self_attn/o_proj/kernel', PartitionSpec(('fsdp', 'sp'), "tp")),
    ('mlp/gate_proj/kernel', PartitionSpec(('fsdp', 'sp'), "tp")),
    ('mlp/down_proj/kernel', PartitionSpec(('fsdp', 'sp'), "tp")),
    ('mlp/up_proj/kernel', PartitionSpec(('fsdp', 'sp'), "tp")),
    ('input_layernorm/kernel', PartitionSpec(None)),
    ('post_attention_layernorm/kernel', PartitionSpec(None)),
    ('model/norm/kernel', PartitionSpec(None)),
    ('lm_head/kernel', PartitionSpec(('fsdp', 'sp'), "tp")),
    ('.*', PartitionSpec(('fsdp', 'sp'))),
)
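# Added note: these regexes are matched against flattened parameter paths, in
# order, so e.g. "model/layers/0/self_attn/q_proj/kernel" should hit the q_proj
# rule and be laid out as PartitionSpec(('fsdp', 'sp'), "tp") -- rows over the
# fsdp/sp axes, columns over tp -- while the final '.*' entry is the catch-all
# for anything the earlier rules miss.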
config.get_partition_rules = lambda _: rules
config.add_basic_configurations(
    attn_mechanism=attn_mechanism,
    shard_attention_computation=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": input_shape
}
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="ol-sft",
    custom_rule=config.get_partition_rules(True),  # use the custom partition rules defined above
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.WARM_UP_COSINE,
    weight_decay=0.01,
    total_batch_size=64,
    max_training_steps=None,  # None lets the trainer decide from epochs and dataset size
    do_train=True,
    do_eval=False,  # optional, but supported
    backend="tpu",  # the default backend is cpu, so set tpu, gpu, or cpu explicitly
    max_sequence_length=max_length,  # note that this has to match the model config too
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=sharding_axis_dims,  # how the model is sharded across CPUs, GPUs, or TPUs;
    # with (1, -1, 1, 1) parameters are fully sharded across all devices and the
    # trainer handles the cross-device communication automatically
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    do_shard_fns=True,
    use_wandb=True,
    track_memory=False  # install Go and set this to True if you want memory tracking
)
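# Added note: with total_batch_size=64 and gradient_accumulation_steps=8, each
# optimizer update should see 64 * 8 = 512 packed sequences of 2048 tokens,
# assuming accumulation multiplies the per-step batch rather than dividing it.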
def prompter(sample):
    return [conversations_formatting_function(tokenizer, messages_field="messages")(sample)]
train_dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
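# Optional preview (added; assumes the "train_sft" split carries a "messages"
# field, which ultrachat_200k does): inspect one formatted sample before training.
print(prompter(train_dataset[0])[0][:500])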
trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,  # we don't have an eval dataset right now
    tokenizer=tokenizer,
    dataset_text_field=None,
    formatting_func=prompter,
    packing=True,
    num_of_sequences=max_length,
)
output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")