Skip to content

Instantly share code, notes, and snippets.

@EvilFreelancer
Last active February 3, 2024 13:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save EvilFreelancer/3f5166ff9b6de1e4adccf6ea192ab9ca to your computer and use it in GitHub Desktop.
Пример локального запуска модели NLLB (перевод rus -> eng) — Example of running the NLLB model locally (Russian → English translation)
# Languages supported by this wrapper around NLLB-200.
# Each entry maps between three naming schemes:
#   "long"  — the NLLB/FLORES-200 language code (ISO 639-3 + script tag),
#             the format expected by the NLLB tokenizer/model;
#   "short" — the short code (ISO 639-1 style) as returned by langdetect;
#   "name"  — the human-readable English language name.
LANGUAGES = [
{"long": "afr_Latn", "short": "af", "name": "Afrikaans"},
{"long": "als_Latn", "short": "sq", "name": "Albanian"},
{"long": "amh_Ethi", "short": "am", "name": "Amharic"},
{"long": "arb_Arab", "short": "ar", "name": "Arabic"},
{"long": "ast_Latn", "short": "ast", "name": "Asturian"},
{"long": "azj_Latn", "short": "az", "name": "Azerbaijani"},
{"long": "bel_Cyrl", "short": "be", "name": "Belarusian"},
{"long": "ben_Beng", "short": "bn", "name": "Bengali"},
{"long": "bul_Cyrl", "short": "bg", "name": "Bulgarian"},
{"long": "cat_Latn", "short": "ca", "name": "Catalan"},
{"long": "ceb_Latn", "short": "ceb", "name": "Cebuano"},
{"long": "ces_Latn", "short": "cs", "name": "Czech"},
{"long": "cym_Latn", "short": "cy", "name": "Welsh"},
{"long": "dan_Latn", "short": "da", "name": "Danish"},
{"long": "deu_Latn", "short": "de", "name": "German"},
{"long": "ell_Grek", "short": "el", "name": "Greek"},
{"long": "eng_Latn", "short": "en", "name": "English"},
{"long": "epo_Latn", "short": "eo", "name": "Esperanto"},
{"long": "est_Latn", "short": "et", "name": "Estonian"},
{"long": "fin_Latn", "short": "fi", "name": "Finnish"},
{"long": "fra_Latn", "short": "fr", "name": "French"},
{"long": "gaz_Latn", "short": "om", "name": "Oromo"},
{"long": "gla_Latn", "short": "gd", "name": "Scottish Gaelic"},
{"long": "gle_Latn", "short": "ga", "name": "Irish"},
{"long": "glg_Latn", "short": "gl", "name": "Galician"},
{"long": "hau_Latn", "short": "ha", "name": "Hausa"},
{"long": "heb_Hebr", "short": "he", "name": "Hebrew"},
{"long": "hin_Deva", "short": "hi", "name": "Hindi"},
{"long": "hrv_Latn", "short": "hr", "name": "Croatian"},
{"long": "hun_Latn", "short": "hu", "name": "Hungarian"},
{"long": "hye_Armn", "short": "hy", "name": "Armenian"},
{"long": "ibo_Latn", "short": "ig", "name": "Igbo"},
{"long": "ilo_Latn", "short": "ilo", "name": "Ilocano"},
{"long": "ind_Latn", "short": "id", "name": "Indonesian"},
{"long": "isl_Latn", "short": "is", "name": "Icelandic"},
{"long": "ita_Latn", "short": "it", "name": "Italian"},
{"long": "jav_Latn", "short": "jv", "name": "Javanese"},
{"long": "jpn_Jpan", "short": "ja", "name": "Japanese"},
{"long": "kat_Geor", "short": "ka", "name": "Georgian"},
{"long": "kaz_Cyrl", "short": "kk", "name": "Kazakh"},
{"long": "khm_Khmr", "short": "km", "name": "Khmer"},
{"long": "kor_Hang", "short": "ko", "name": "Korean"},
{"long": "lit_Latn", "short": "lt", "name": "Lithuanian"},
{"long": "ltz_Latn", "short": "lb", "name": "Luxembourgish"},
{"long": "lug_Latn", "short": "lg", "name": "Ganda"},
{"long": "lvs_Latn", "short": "lv", "name": "Latvian"},
{"long": "mal_Mlym", "short": "ml", "name": "Malayalam"},
{"long": "mar_Deva", "short": "mr", "name": "Marathi"},
{"long": "mkd_Cyrl", "short": "mk", "name": "Macedonian"},
{"long": "mya_Mymr", "short": "my", "name": "Burmese"},
{"long": "nld_Latn", "short": "nl", "name": "Dutch"},
{"long": "nob_Latn", "short": "no", "name": "Norwegian Bokmål"},
{"long": "npi_Deva", "short": "ne", "name": "Nepali"},
{"long": "oci_Latn", "short": "oc", "name": "Occitan"},
{"long": "ory_Orya", "short": "or", "name": "Odia"},
{"long": "pes_Arab", "short": "fa", "name": "Persian"},
{"long": "plt_Latn", "short": "mg", "name": "Malagasy"},
{"long": "pol_Latn", "short": "pl", "name": "Polish"},
{"long": "por_Latn", "short": "pt", "name": "Portuguese"},
{"long": "ron_Latn", "short": "ro", "name": "Romanian"},
{"long": "rus_Cyrl", "short": "ru", "name": "Russian"},
{"long": "sin_Sinh", "short": "si", "name": "Sinhala"},
{"long": "slk_Latn", "short": "sk", "name": "Slovak"},
{"long": "slv_Latn", "short": "sl", "name": "Slovenian"},
{"long": "snd_Arab", "short": "sd", "name": "Sindhi"},
{"long": "som_Latn", "short": "so", "name": "Somali"},
{"long": "spa_Latn", "short": "es", "name": "Spanish"},
{"long": "srp_Cyrl", "short": "sr", "name": "Serbian"},
{"long": "sun_Latn", "short": "su", "name": "Sundanese"},
{"long": "swe_Latn", "short": "sv", "name": "Swedish"},
{"long": "swh_Latn", "short": "sw", "name": "Swahili"},
{"long": "tam_Taml", "short": "ta", "name": "Tamil"},
{"long": "tat_Cyrl", "short": "tt", "name": "Tatar"},
{"long": "tgl_Latn", "short": "tl", "name": "Tagalog"},
{"long": "tur_Latn", "short": "tr", "name": "Turkish"},
{"long": "ukr_Cyrl", "short": "uk", "name": "Ukrainian"},
{"long": "urd_Arab", "short": "ur", "name": "Urdu"},
{"long": "uzn_Latn", "short": "uz", "name": "Uzbek"},
{"long": "vie_Latn", "short": "vi", "name": "Vietnamese"},
{"long": "wol_Latn", "short": "wo", "name": "Wolof"},
{"long": "xho_Latn", "short": "xh", "name": "Xhosa"},
{"long": "ydd_Hebr", "short": "yi", "name": "Yiddish"},
{"long": "yor_Latn", "short": "yo", "name": "Yoruba"},
{"long": "zho_Hans", "short": "zh", "name": "Chinese (Simplified)"},
{"long": "zsm_Latn", "short": "ms", "name": "Malay"},
{"long": "zul_Latn", "short": "zu", "name": "Zulu"},
]
# Lookup tables derived from LANGUAGES for O(1) code conversion in each direction.
SHORT_TO_LONG = dict((entry["short"], entry["long"]) for entry in LANGUAGES)
LONG_TO_SHORT = dict((entry["long"], entry["short"]) for entry in LANGUAGES)
def get_language_long(short):
    """Return the NLLB long code for a short (langdetect-style) code, or None if unknown."""
    try:
        return SHORT_TO_LONG[short]
    except KeyError:
        return None
def get_language_short(long):
    """Return the short code for an NLLB long code, or None if unknown."""
    return LONG_TO_SHORT.get(long, None)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model_name = "facebook/nllb-200-distilled-600M"
from_lang = "rus_Cyrl"
to_lang = "eng_Latn"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=True, src_lang=from_lang)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=True)
article = "Всем привет! У микрофона Павел Злой и сегодня мы поговорим о..."
inputs = tokenizer(article, return_tensors="pt")
translated_tokens = model.generate(
**inputs,
forced_bos_token_id=tokenizer.lang_code_to_id[to_lang],
max_length=30
)
tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from langdetect import detect
from nllb_languages import get_language_long
# Default Hugging Face checkpoint used by NLLBModel when none is given.
DEFAULT_MODEL = "facebook/nllb-200-distilled-600M"
# Default cap on generated-sequence length passed to model.generate().
DEFAULT_MAX_LENGTH = 1024
class NLLBModel:
    """Thin wrapper around an NLLB-200 seq2seq checkpoint.

    Provides source-language detection (via langdetect) and text
    translation between NLLB long language codes (e.g. ``rus_Cyrl``).
    """

    def __init__(self, model_name=DEFAULT_MODEL):
        """Load the seq2seq model from the Hugging Face Hub.

        :param model_name: Hub checkpoint id; defaults to the distilled
            600M NLLB-200 model.
        """
        self.model_name = model_name
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name, token=True)
        # Tokenizers cached per source language: AutoTokenizer.from_pretrained
        # is an expensive disk/network operation, and NLLB needs src_lang
        # fixed at tokenizer-construction time, so one tokenizer per language.
        self._tokenizers = {}

    def _tokenizer_for(self, src_lang: str):
        """Return a cached tokenizer configured for the given source language."""
        if src_lang not in self._tokenizers:
            self._tokenizers[src_lang] = AutoTokenizer.from_pretrained(
                self.model_name, token=True, src_lang=src_lang
            )
        return self._tokenizers[src_lang]

    def detect(self, text: str) -> str:
        """Detect the language of *text* and return its NLLB long code.

        Returns ``None`` when langdetect's short code has no entry in the
        language table (see ``get_language_long``).
        """
        detected_language = detect(text)
        return get_language_long(detected_language)

    def translate(self, from_lng: str, to_lng: str, text: str, max_length=DEFAULT_MAX_LENGTH) -> str:
        """Translate *text* from *from_lng* to *to_lng* (NLLB long codes).

        :param from_lng: source language code, e.g. ``rus_Cyrl``
        :param to_lng: target language code, e.g. ``eng_Latn``
        :param text: text to translate
        :param max_length: cap on the generated sequence length
        :return: the translated text (first sequence of the batch)
        """
        tokenizer = self._tokenizer_for(from_lng)
        inputs = tokenizer(text, return_tensors="pt")
        translated_tokens = self.model.generate(
            **inputs,
            # Force the decoder to begin with the target-language tag.
            # convert_tokens_to_ids replaces the lang_code_to_id mapping,
            # which was deprecated and removed in recent transformers releases.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(to_lng),
            max_length=max_length,
        )
        return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment