Skip to content

Instantly share code, notes, and snippets.

@ArthurZucker
Created November 10, 2023 15:05
Show Gist options
  • Save ArthurZucker/159dedfcb908467e5f484cf1c143155e to your computer and use it in GitHub Desktop.
Save ArthurZucker/159dedfcb908467e5f484cf1c143155e to your computer and use it in GitHub Desktop.
Script to automatically convert and re-upload Helsinki-NLP Marian translation models, comparing the new generation results against the previous ones
#!/bin/bash
#
# Convert and re-upload Helsinki-NLP Marian checkpoints, comparing the text
# generated under transformers 4.29 against 4.34.
#
# Environment setup (one conda env per transformers version):
#   conda create -n 4.29 python==3.9
#   source activate 4.29
#   pip install transformers==4.29.2
#   pip install torch accelerate sentencepiece tokenizers colorama sacremoses googletrans
#   conda create -n 4.34 python==3.9
#   source activate 4.34
#   pip install transformers==4.34
#   pip install torch accelerate sentencepiece tokenizers colorama sacremoses googletrans

# ANSI escape sequences for colored status lines (interpreted by `echo -e`).
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Emojis used in the status output.
SMILEY='😊'
THUMBS_UP='👍'
WARNING='⚠️'

# Language codes accepted by googletrans; models whose source language is not
# listed here fall back to a default English prompt.
LANGUAGES=(
  'af' 'sq' 'am' 'ar' 'hy' 'az' 'eu' 'be' 'bn' 'bs' 'bg' 'ca' 'ceb' 'ny'
  'zh-cn' 'zh-tw' 'co' 'hr' 'cs' 'da' 'nl' 'en' 'eo' 'et' 'tl' 'fi' 'fr' 'fy'
  'gl' 'ka' 'de' 'el' 'gu' 'ht' 'ha' 'haw' 'iw' 'he' 'hi' 'hmn' 'hu' 'is'
  'ig' 'id' 'ga' 'it' 'ja' 'jw' 'kn' 'kk' 'km' 'ko' 'ku' 'ky' 'lo' 'la'
  'lv' 'lt' 'lb' 'mk' 'mg' 'ms' 'ml' 'mt' 'mi' 'mr' 'mn' 'my' 'ne' 'no'
  'or' 'ps' 'fa' 'pl' 'pt' 'pa' 'ro' 'ru' 'sm' 'gd' 'sr' 'st' 'sn' 'sd'
  'si' 'sk' 'sl' 'so' 'es' 'su' 'sw' 'sv' 'tg' 'ta' 'te' 'th' 'tr' 'uk'
  'ur' 'ug' 'uz' 'vi' 'cy' 'xh' 'yi' 'yo' 'zu'
)
# Task 1: Get list of model_ids.
# Queries the Hub HTTP API for all Helsinki-NLP Marian models, sorted by
# all-time downloads (most popular first); "-big" variants are excluded.
#
# NOTE: the original script first issued an equivalent huggingface_hub query
# whose result was immediately overwritten by this one; that dead network
# call has been removed.
model_ids=$(python -c "
import requests
params = {
    'author': 'Helsinki-NLP',
    'other': 'marian',
    'expand[]': 'downloadsAllTime',
}
response = requests.get('https://huggingface.co/api/models', params=params)
# Fail loudly on HTTP errors instead of choking on a non-JSON body.
response.raise_for_status()
models = response.json()
# .get with a default: entries missing the download count sort last
# instead of raising KeyError.
model_list = sorted(models, key=lambda e: e.get('downloadsAllTime', 0), reverse=True)
model_list = [model['id'] for model in model_list if 'big' not in model['id']]
print('\n'.join(model_list))
")
# Root of the conda installation; the per-version envs live under envs/.
CONDA_PATH=$(conda info --base)
# Interpolated verbatim into the Python here-docs below, hence the nested
# quotes: the Python source receives the string literal 'cuda'.
DEVICE="'cuda'"
# Tasks 2+3: for every model, build a native-language test sentence, generate
# with the checkpoint under transformers 4.29 and 4.34, and when the outputs
# disagree push re-tied weights as a PR and auto-merge it.
# (Word-splitting of $model_ids is intentional: ids contain no spaces.)
for model_id in $model_ids; do
  echo -e " - $model_id"
  # Ids look like Helsinki-NLP/opus-mt-<src>-<tgt>; the second-to-last
  # dash-separated field is the source language.
  source_lang=$(echo "$model_id" | awk -F'-' '{print $(NF-1)}')
  # A field longer than 3 characters means the id did not parse as <src>-<tgt>.
  if [ ${#source_lang} -gt 3 ]; then
    echo -e "${YELLOW}${WARNING} Skipping: $model_id with $source_lang${NC}"
    continue
  fi
  # Build the test input: translate a fixed English prompt into the source
  # language when googletrans supports it, else use a default English prompt.
  if [[ " ${LANGUAGES[@]} " =~ " $source_lang " ]]; then
    translation=$($CONDA_PATH/envs/py39/bin/python - <<EOF
from googletrans import Translator
from logging import captureWarnings, getLogger, ERROR
# Silence warnings so only the translation (or error text) reaches stdout.
logger = getLogger('py.warnings')
logger.setLevel(ERROR)
captureWarnings(True)
translator = Translator()
source_lang = "$source_lang"
try:
    translation = translator.translate("Hey! Let\'s learn together", dest=source_lang)
    from colorama import Fore, Back, Style
    print(translation.text)
except Exception as e:
    print(f"Error translating {source_lang}: {str(e)}")
    translation = None
EOF
)
    echo -e "${GREEN}${SMILEY}Testing input $model_id with input sentence from $source_lang: $translation ${NC}"
  else
    translation="' >>en<< Hey how are you?'"
    echo -e "${YELLOW} \tFailed to translate to $source_lang, using the default english prompt $translation ${NC}"
  fi
  captured_output="''"
  # Task 3: Run scripts with MarianMTModel.
  formatted_model_name=$(echo $model_id | tr '/' '_')
  # FIX: the original interpolated the never-defined ${conda_env} here, so
  # both the 4.29 and the 4.34 runs wrote their logits to the same
  # "<model>_.pt" file, the second overwriting the first.  Name the dump per
  # environment instead so the results can actually be compared.
  output_file="${formatted_model_name}_4.29.pt"
  # Reference run under transformers 4.29: save logits and capture the
  # generated translation (printed as a Python list of strings).
  python_path="$CONDA_PATH/envs/4.29/bin/python"
  captured_output=$(TRANSFORMERS_VERBOSITY=error $python_path - <<EOF
from transformers import AutoTokenizer, MarianMTModel
from logging import captureWarnings, getLogger, ERROR
import torch, os, transformers
logger = getLogger('py.warnings')
logger.setLevel(ERROR)
captureWarnings(True)
transformers.utils.logging.set_verbosity_error()
tokenizer = AutoTokenizer.from_pretrained("$model_id")
inputs = tokenizer("$translation", return_tensors="pt", padding=True).to($DEVICE)
model = MarianMTModel.from_pretrained("$model_id").to($DEVICE)
torch.save(model(**inputs, decoder_input_ids = inputs["input_ids"]).logits.detach(),'Arthur/$output_file')
print(tokenizer.batch_decode(model.generate(**inputs)))
EOF
)
  # PR description attached to the automated fix; embeds the 4.29 output so
  # reviewers can compare.  (These lines are inside a quoted string and must
  # stay unindented.)
  commit_description="""Following the merge of [a PR](https://github.com/huggingface/transformers/pull/24310) in \`transformers\` it appeared that \
this model was not properly converted. This PR will fix the inference and was tested using the following script:
\`\`\`python
>>> from transformers import AutoTokenizer, MarianMTModel
>>> tokenizer = AutoTokenizer.from_pretrained('$model_id')
>>> inputs = tokenizer(\"$translation\", return_tensors=\"pt\", padding=True)
>>> model = MarianMTModel.from_pretrained('$model_id')
>>> print(tokenizer.batch_decode(model.generate(**inputs)))
"$captured_output"
\`\`\`
"""
  echo -e "${YELLOW}🤗 transformers == 4.29.1: $captured_output ${NC}"
  # Comparison run under transformers 4.34 (env "py39"): generate before and
  # after re-tying lm_head to the shared embeddings, push the fixed weights
  # when they differ, and auto-merge the resulting PR.
  output_file="${formatted_model_name}_py39.pt"
  python_path="$CONDA_PATH/envs/py39/bin/python"
  TRANSFORMERS_VERBOSITY=error $python_path - <<EOF
from transformers import AutoTokenizer, MarianMTModel, MarianModel
from logging import captureWarnings, getLogger, ERROR
import torch, os, transformers
from huggingface_hub import HfApi
api = HfApi()
logger = getLogger('py.warnings')
logger.setLevel(ERROR)
captureWarnings(True)
transformers.utils.logging.set_verbosity_error()
tokenizer = AutoTokenizer.from_pretrained("$model_id")
inputs = tokenizer("$translation", return_tensors="pt", padding=True).to($DEVICE)
model = MarianMTModel.from_pretrained("$model_id", torch_dtype="auto").to($DEVICE)
logits = model(**inputs, decoder_input_ids = inputs["input_ids"]).logits
torch.save(logits.detach(),'Arthur/$output_file')
from colorama import Fore, Back, Style
translated = model.generate(**inputs)
# NB: the captured 4.29 output (a printed list) is interpolated verbatim
# into this source, so it must remain a valid Python literal.
color = Fore.GREEN if $captured_output == tokenizer.batch_decode(translated) else Fore.RED
print(color + "🤗 transformers == 4.34-before:\t", tokenizer.batch_decode(translated))
# Re-tie the LM head to the shared embedding matrix: this is the fix made
# necessary by the linked transformers PR.
model_base = MarianModel.from_pretrained("$model_id", torch_dtype="auto").to($DEVICE)
model.model = model_base
model.lm_head.weight.data = model.model.shared.weight.data
translated_fixed = model.generate(**inputs)
color = Fore.GREEN if $captured_output == tokenizer.batch_decode(translated_fixed) else Fore.RED
print(color + "🤗 transformers == 4.34-fixed:\t", tokenizer.batch_decode(translated_fixed))
print(Style.RESET_ALL)
# NOTE(review): len() of a 2-D generate output is the batch dimension, not
# the sequence length -- presumably the bound was meant for the latter; confirm.
if tokenizer.batch_decode(translated) != tokenizer.batch_decode(translated_fixed) and len(translated_fixed)<40:
    print("writing $model_id to a file")
    del model.config.transformers_version
    del model.config._name_or_path
    commit_details = model.push_to_hub("$model_id", create_pr = True, commit_message="Update checkpoint for transformers>=4.29", commit_description="""${commit_description}""")
    print(commit_details.pr_url)
    api.merge_pull_request("$model_id", discussion_num = commit_details.pr_num, comment = "Automatically merging the PR.")
    with open("Arthur/files_to_update.txt", 'a+') as f:
        f.write(f"$model_id\n")
EOF
  # Clear the Hugging Face cache so repeated downloads do not fill the disk.
  rm -rf /Users/arthur/.cache/huggingface/hub/models--Helsinki-NLP--opus-mt-*
done
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment