tinyllama_mole
'''
Load a TinyLlama mixture-of-experts (MoE) SFT checkpoint and run a single forward pass.

Usage:
    python test.py recipes/tinyllama_mole/sft/config_routeraux_ep3.yaml
'''
import logging
import random
import sys
import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, set_seed
from trl import SFTTrainer
from configs import (
    DataArguments,
    H4ArgumentParser,
    ModelArguments,
    SFTConfig,
)
from model_utils import (
    get_checkpoint,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
)
from data import (
    apply_chat_template,
    get_datasets,
)
parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
model_args, data_args, training_args = parser.parse()
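# Not in the original gist: seeding mirrors the alignment-handbook-style SFT
# scripts these helpers come from. training_args.seed is assumed to come from
# the YAML config (TrainingArguments defaults it to 42 otherwise).
set_seed(training_args.seed)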
model_kwargs = dict(
    revision='main',
    trust_remote_code=True,
    use_flash_attention_2=True,
    torch_dtype='bfloat16',
    use_cache=False,  # disable the KV cache; it is only needed for generation
    device_map=None,
    quantization_config=None,
    output_router_logits=True,  # include per-layer router logits in outputs
    router_aux_loss_coef=0.05,  # weight of the MoE load-balancing aux loss
)
model_name_or_path = 'ondevicellm/tinyllama_mole_sft_routeraux_ultrachat_ep3'
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
tokenizer = get_tokenizer(model_args, data_args)
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")
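# A minimal generation sketch, not in the original gist: generate() is the
# standard transformers text-generation API; max_new_tokens=64 is an
# arbitrary choice for illustration.
generated = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(generated[0], skip_special_tokens=True))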
'''
The forward pass below returns a MoeCausalLMOutputWithPast:

return MoeCausalLMOutputWithPast(
    loss=loss,
    aux_loss=aux_loss,
    logits=logits,
    past_key_values=outputs.past_key_values,
    hidden_states=outputs.hidden_states,
    attentions=outputs.attentions,
    router_logits=outputs.router_logits,
)
'''
with torch.no_grad():  # inference only; no gradients needed
    outputs = model(**inputs)
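# Inspecting the MoE-specific fields (a sketch; shapes assume a Mixtral-style
# MoE forward, where router_logits is a per-layer tuple of
# (batch * seq_len, num_experts) tensors). aux_loss is populated because
# output_router_logits=True was set in model_kwargs above.
print('aux_loss:', outputs.aux_loss)
print('router_logits:', len(outputs.router_logits), 'layers, each of shape',
      tuple(outputs.router_logits[0].shape))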