Upload a merged model to Hugging Face
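A minimal sketch of the intended invocation, assuming the gist is saved as upload_merge.py (the filename, model name, and token value below are placeholders, not part of the gist):

"""Create a repo on the Hugging Face Hub for a merged model and upload the local 'merge' folder.

Example:
    python upload_merge.py my-merged-model -y merge_config.yaml -u Technoculture -t hf_xxx --new
"""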
import argparse
import yaml
from huggingface_hub import ModelCard, HfApi
from jinja2 import Template
# Argument parser setup
parser = argparse.ArgumentParser(description='Create a repo on Hugging Face for the model and upload it.')
parser.add_argument('model_name', type=str, help='Model name')
parser.add_argument('--new', action='store_true', help='Create a new repo')
parser.add_argument('-y', '--yaml_path', type=str, required=True, help='Path to YAML configuration file')
parser.add_argument('-u', '--username', type=str, default="Technoculture", help='Hugging Face username')
parser.add_argument('-t', '--token', type=str, required=True, help='Hugging Face API token')
parser.add_argument('-m', '--is_moe', action='store_true', help='Treat the model as a Mixture of Experts (MoE) merge')
args = parser.parse_args()
# Read YAML configuration file
with open(args.yaml_path, 'r') as file:
    yaml_config = file.read()
data = yaml.safe_load(yaml_config)
print(data)

# Branch selection and template setup
is_moe = args.is_moe
if not is_moe:
    template_text = """
---
license: apache-2.0
tags:
- merge
- mergekit
{%- for model in models %}
- {{ model }}
{%- endfor %}
---
# {{ model_name }}
{{ model_name }} is a merge of the following models:
{%- for model in models %}
* [{{ model }}](https://huggingface.co/{{ model }})
{%- endfor %}
## 🧩 Configuration
```yaml
{{- yaml_config -}}
```
## 💻 Usage
```python
!pip install -qU transformers accelerate
from transformers import AutoTokenizer
import transformers
import torch
model = "{{ username }}/{{ model_name }}"
messages = [{"role": "user", "content": "I am feeling sleepy these days"}]
tokenizer = AutoTokenizer.from_pretrained(model)
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)
outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])
```
"""
    # Create a Jinja template object
    jinja_template = Template(template_text.strip())
    # Get list of models from config
    if "models" in data:
        models = [m["model"] for m in data["models"] if "parameters" in m]
    elif "parameters" in data:
        models = [s["model"] for s in data["slices"][0]["sources"]]
    elif "slices" in data:
        models = [s["sources"][0]["model"] for s in data["slices"]]
    else:
        raise Exception("No models or slices found in yaml config")
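    # For reference, the config shapes handled above look roughly like the
    # following mergekit-style examples (model names are hypothetical, not
    # taken from the gist):
    #
    #   models:
    #     - model: org/base-model
    #     - model: org/model-a
    #       parameters:
    #         density: 0.5
    #         weight: 0.5
    #
    # or, for a slice-based merge:
    #
    #   slices:
    #     - sources:
    #         - model: org/model-a
    #           layer_range: [0, 32]
    #         - model: org/model-b
    #           layer_range: [0, 32]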
    # Fill the template
    content = jinja_template.render(
        model_name=args.model_name,
        models=models,
        yaml_config=yaml_config,
        username=args.username,
    )
else:
    template_text = """
---
license: apache-2.0
tags:
- moe
- merge
{%- for model in models %}
- {{ model }}
{%- endfor %}
---
# {{ model_name }}
{{ model_name }} is a Mixture of Experts (MoE) made with the following models:
{%- for model in models %}
* [{{ model }}](https://huggingface.co/{{ model }})
{%- endfor %}
## 🧩 Configuration
```yaml
{{- yaml_config -}}
```
## 💻 Usage
```python
!pip install -qU transformers bitsandbytes accelerate
from transformers import AutoTokenizer
import transformers
import torch
model = "{{ username }}/{{ model_name }}"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": True},
)
messages = [{"role": "user", "content": "Explain what a Mixture of Experts is in less than 100 words."}]
prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])
```
"""
    # Create a Jinja template object
    jinja_template = Template(template_text.strip())
    # Fill the template
    models = [model['source_model'] for model in data['experts']]
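    # The 'experts' list above corresponds to a mergekit-moe style config,
    # roughly like the following (model names and prompts are hypothetical,
    # not taken from the gist):
    #
    #   base_model: org/base-model
    #   experts:
    #     - source_model: org/expert-a
    #       positive_prompts: ["chat", "reasoning"]
    #     - source_model: org/expert-b
    #       positive_prompts: ["code"]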
    content = jinja_template.render(
        model_name=args.model_name,
        models=models,
        yaml_config=yaml_config,
        username=args.username,
    )
# Save the model card
card = ModelCard(content)
card.save('merge/README.md')
# Authenticate with the Hugging Face Hub using the provided token
api = HfApi(token=args.token)
if args.new:
    api.create_repo(
        repo_id=f"{args.username}/{args.model_name}",
        repo_type="model"
    )
# Upload merge folder
api.upload_folder(
    repo_id=f"{args.username}/{args.model_name}",
    folder_path="merge",
)