Last active
February 3, 2024 13:32
-
-
Save EvilFreelancer/3f5166ff9b6de1e4adccf6ea192ab9ca to your computer and use it in GitHub Desktop.
Пример локального запуска модели NLLB (перевод rus -> eng)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Table of supported languages. Each entry carries:
#   "long"  - the NLLB/FLORES-200 style code (language + script, e.g. "rus_Cyrl"),
#   "short" - the short code used as lookup key (e.g. "ru"),
#   "name"  - a human-readable English name.
LANGUAGES = [
    {"long": "afr_Latn", "short": "af", "name": "Afrikaans"},
    {"long": "als_Latn", "short": "sq", "name": "Albanian"},
    {"long": "amh_Ethi", "short": "am", "name": "Amharic"},
    {"long": "arb_Arab", "short": "ar", "name": "Arabic"},
    {"long": "ast_Latn", "short": "ast", "name": "Asturian"},
    {"long": "azj_Latn", "short": "az", "name": "Azerbaijani"},
    {"long": "bel_Cyrl", "short": "be", "name": "Belarusian"},
    {"long": "ben_Beng", "short": "bn", "name": "Bengali"},
    {"long": "bul_Cyrl", "short": "bg", "name": "Bulgarian"},
    {"long": "cat_Latn", "short": "ca", "name": "Catalan"},
    {"long": "ceb_Latn", "short": "ceb", "name": "Cebuano"},
    {"long": "ces_Latn", "short": "cs", "name": "Czech"},
    {"long": "cym_Latn", "short": "cy", "name": "Welsh"},
    {"long": "dan_Latn", "short": "da", "name": "Danish"},
    {"long": "deu_Latn", "short": "de", "name": "German"},
    {"long": "ell_Grek", "short": "el", "name": "Greek"},
    {"long": "eng_Latn", "short": "en", "name": "English"},
    {"long": "epo_Latn", "short": "eo", "name": "Esperanto"},
    {"long": "est_Latn", "short": "et", "name": "Estonian"},
    {"long": "fin_Latn", "short": "fi", "name": "Finnish"},
    {"long": "fra_Latn", "short": "fr", "name": "French"},
    {"long": "gaz_Latn", "short": "om", "name": "Oromo"},
    {"long": "gla_Latn", "short": "gd", "name": "Scottish Gaelic"},
    {"long": "gle_Latn", "short": "ga", "name": "Irish"},
    {"long": "glg_Latn", "short": "gl", "name": "Galician"},
    {"long": "hau_Latn", "short": "ha", "name": "Hausa"},
    {"long": "heb_Hebr", "short": "he", "name": "Hebrew"},
    {"long": "hin_Deva", "short": "hi", "name": "Hindi"},
    {"long": "hrv_Latn", "short": "hr", "name": "Croatian"},
    {"long": "hun_Latn", "short": "hu", "name": "Hungarian"},
    {"long": "hye_Armn", "short": "hy", "name": "Armenian"},
    {"long": "ibo_Latn", "short": "ig", "name": "Igbo"},
    {"long": "ilo_Latn", "short": "ilo", "name": "Ilocano"},
    {"long": "ind_Latn", "short": "id", "name": "Indonesian"},
    {"long": "isl_Latn", "short": "is", "name": "Icelandic"},
    {"long": "ita_Latn", "short": "it", "name": "Italian"},
    {"long": "jav_Latn", "short": "jv", "name": "Javanese"},
    {"long": "jpn_Jpan", "short": "ja", "name": "Japanese"},
    {"long": "kat_Geor", "short": "ka", "name": "Georgian"},
    {"long": "kaz_Cyrl", "short": "kk", "name": "Kazakh"},
    {"long": "khm_Khmr", "short": "km", "name": "Khmer"},
    {"long": "kor_Hang", "short": "ko", "name": "Korean"},
    {"long": "lit_Latn", "short": "lt", "name": "Lithuanian"},
    {"long": "ltz_Latn", "short": "lb", "name": "Luxembourgish"},
    {"long": "lug_Latn", "short": "lg", "name": "Ganda"},
    {"long": "lvs_Latn", "short": "lv", "name": "Latvian"},
    {"long": "mal_Mlym", "short": "ml", "name": "Malayalam"},
    {"long": "mar_Deva", "short": "mr", "name": "Marathi"},
    {"long": "mkd_Cyrl", "short": "mk", "name": "Macedonian"},
    {"long": "mya_Mymr", "short": "my", "name": "Burmese"},
    {"long": "nld_Latn", "short": "nl", "name": "Dutch"},
    {"long": "nob_Latn", "short": "no", "name": "Norwegian Bokmål"},
    {"long": "npi_Deva", "short": "ne", "name": "Nepali"},
    {"long": "oci_Latn", "short": "oc", "name": "Occitan"},
    {"long": "ory_Orya", "short": "or", "name": "Odia"},
    {"long": "pes_Arab", "short": "fa", "name": "Persian"},
    {"long": "plt_Latn", "short": "mg", "name": "Malagasy"},
    {"long": "pol_Latn", "short": "pl", "name": "Polish"},
    {"long": "por_Latn", "short": "pt", "name": "Portuguese"},
    {"long": "ron_Latn", "short": "ro", "name": "Romanian"},
    {"long": "rus_Cyrl", "short": "ru", "name": "Russian"},
    {"long": "sin_Sinh", "short": "si", "name": "Sinhala"},
    {"long": "slk_Latn", "short": "sk", "name": "Slovak"},
    {"long": "slv_Latn", "short": "sl", "name": "Slovenian"},
    {"long": "snd_Arab", "short": "sd", "name": "Sindhi"},
    {"long": "som_Latn", "short": "so", "name": "Somali"},
    {"long": "spa_Latn", "short": "es", "name": "Spanish"},
    {"long": "srp_Cyrl", "short": "sr", "name": "Serbian"},
    {"long": "sun_Latn", "short": "su", "name": "Sundanese"},
    {"long": "swe_Latn", "short": "sv", "name": "Swedish"},
    {"long": "swh_Latn", "short": "sw", "name": "Swahili"},
    {"long": "tam_Taml", "short": "ta", "name": "Tamil"},
    {"long": "tat_Cyrl", "short": "tt", "name": "Tatar"},
    {"long": "tgl_Latn", "short": "tl", "name": "Tagalog"},
    {"long": "tur_Latn", "short": "tr", "name": "Turkish"},
    {"long": "ukr_Cyrl", "short": "uk", "name": "Ukrainian"},
    {"long": "urd_Arab", "short": "ur", "name": "Urdu"},
    {"long": "uzn_Latn", "short": "uz", "name": "Uzbek"},
    {"long": "vie_Latn", "short": "vi", "name": "Vietnamese"},
    {"long": "wol_Latn", "short": "wo", "name": "Wolof"},
    {"long": "xho_Latn", "short": "xh", "name": "Xhosa"},
    {"long": "ydd_Hebr", "short": "yi", "name": "Yiddish"},
    {"long": "yor_Latn", "short": "yo", "name": "Yoruba"},
    {"long": "zho_Hans", "short": "zh", "name": "Chinese (Simplified)"},
    {"long": "zsm_Latn", "short": "ms", "name": "Malay"},
    {"long": "zul_Latn", "short": "zu", "name": "Zulu"},
]

# Lookup tables derived once from LANGUAGES so that both directions are
# plain dict lookups. If two entries ever shared a key, the later entry
# in LANGUAGES would win.
SHORT_TO_LONG = {lang['short']: lang['long'] for lang in LANGUAGES}
LONG_TO_SHORT = {lang['long']: lang['short'] for lang in LANGUAGES}
def get_language_long(short):
    """Return the long language code for a short code.

    Args:
        short: Short language code used as a key of ``SHORT_TO_LONG``
            (for example ``"ru"``).

    Returns:
        The matching long code (for example ``"rus_Cyrl"``), or ``None``
        when ``short`` is not present in ``SHORT_TO_LONG``.
    """
    return SHORT_TO_LONG.get(short)
def get_language_short(long):
    """Return the short language code for a long code.

    Args:
        long: Long language code used as a key of ``LONG_TO_SHORT``
            (for example ``"eng_Latn"``).

    Returns:
        The matching short code (for example ``"en"``), or ``None``
        when ``long`` is not present in ``LONG_TO_SHORT``.
    """
    return LONG_TO_SHORT.get(long)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal end-to-end example: translate one Russian sentence to English
# with an NLLB checkpoint and print the result.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/nllb-200-distilled-600M"
from_lang = "rus_Cyrl"
to_lang = "eng_Latn"

# The source language is handed to the tokenizer at load time.
# NOTE(review): token=True is forwarded unchanged; presumably it requires a
# stored Hugging Face token - confirm whether that is intended here.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=True, src_lang=from_lang)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=True)

article = "Всем привет! У микрофона Павел Злой и сегодня мы поговорим о..."
inputs = tokenizer(article, return_tensors="pt")

# The target language is selected by forcing its language-code token as the
# first generated token. The id is resolved through convert_tokens_to_ids()
# rather than the deprecated tokenizer.lang_code_to_id mapping, which is not
# available in recent transformers releases.
translated_tokens = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids(to_lang),
    max_length=30,
)

# Bind and print the translation: a bare expression is only displayed in an
# interactive session and would be silently discarded when run as a script.
translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
print(translation)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from langdetect import detect
from nllb_languages import get_language_long

# Checkpoint loaded when NLLBModel is constructed without an explicit name.
DEFAULT_MODEL = "facebook/nllb-200-distilled-600M"
# Default value of the max_length argument forwarded to model.generate().
DEFAULT_MAX_LENGTH = 1024
class NLLBModel:
    """Wrapper around a seq2seq translation checkpoint with language detection.

    The model is loaded once in ``__init__``. The tokenizer is loaded lazily
    on the first ``translate()`` call and reused afterwards.

    Attributes:
        model_name: Name or path of the checkpoint passed to ``from_pretrained``.
        model: The loaded ``AutoModelForSeq2SeqLM`` instance.
    """

    def __init__(self, model_name=DEFAULT_MODEL):
        """Load the seq2seq model for ``model_name``.

        Args:
            model_name: Checkpoint name or path; defaults to ``DEFAULT_MODEL``.
        """
        self.model_name = model_name
        # NOTE(review): token=True is forwarded unchanged; presumably it
        # requires a stored Hugging Face token - confirm this is intended.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name, token=True)
        # Filled on first use by _get_tokenizer().
        self._tokenizer = None

    def _get_tokenizer(self, src_lang: str):
        """Return the shared tokenizer configured for ``src_lang``."""
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            # First use: load from the checkpoint exactly once instead of
            # re-reading the tokenizer files on every translate() call.
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, token=True, src_lang=src_lang
            )
            self._tokenizer = tokenizer
        else:
            # Reuse the loaded tokenizer and only switch its source language.
            tokenizer.src_lang = src_lang
        return tokenizer

    def detect(self, text: str) -> "str | None":
        """Detect the language of ``text`` and map it to a long language code.

        Args:
            text: Text whose language should be detected.

        Returns:
            Whatever ``get_language_long`` returns for the code produced by
            ``langdetect.detect`` - presumably ``None`` when the detected
            code has no mapping (verify against ``nllb_languages``).

        Any exception raised by ``langdetect.detect`` propagates unchanged.
        """
        detected_language = detect(text)
        return get_language_long(detected_language)

    def translate(self, from_lng: str, to_lng: str, text: str, max_length=DEFAULT_MAX_LENGTH) -> str:
        """Translate ``text`` from ``from_lng`` to ``to_lng``.

        Args:
            from_lng: Source language code given to the tokenizer as ``src_lang``.
            to_lng: Target language code; its token is forced as the first
                generated token.
            text: Text to translate.
            max_length: Forwarded to ``model.generate`` as ``max_length``.

        Returns:
            The first decoded sequence, with special tokens skipped.

        Raises:
            KeyError: If ``to_lng`` is not a token known to the tokenizer.
        """
        tokenizer = self._get_tokenizer(from_lng)

        # Resolve the target language token through the vocabulary instead of
        # the deprecated tokenizer.lang_code_to_id mapping, which is not
        # available in recent transformers releases.
        target_id = tokenizer.convert_tokens_to_ids(to_lng)
        if target_id is None or target_id == tokenizer.unk_token_id:
            # Keep the exception type the former dict lookup produced, rather
            # than silently generating with the unknown-token id.
            raise KeyError(to_lng)

        inputs = tokenizer(text, return_tensors="pt")
        translated_tokens = self.model.generate(
            **inputs,
            forced_bos_token_id=target_id,
            max_length=max_length,
        )
        return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment