"""
File: model_translation.py
Description: Loading models for text translations
Author: Didier Guillevic
Date: 2024-03-16
"""

import spaces

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from transformers import BitsAndBytesConfig

from model_spacy import nlp_xx as model_spacy

# 8-bit quantization settings. Currently passed only to the MADLAD model
# (ModelMADLAD); the M2M100 model is loaded without quantization.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=200.0  # https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
)

# The 100 languages supported by the facebook/m2m100_418M model
# https://huggingface.co/facebook/m2m100_418M
# plus the 'AUTOMATIC' option where we will use a language detector.
# Keys are display labels; values are the model's language codes.
language_codes = {
    'AUTOMATIC': 'auto',
    'Afrikaans (af)': 'af',
    'Albanian (sq)': 'sq',
    'Amharic (am)': 'am',
    'Arabic (ar)': 'ar',
    'Armenian (hy)': 'hy',
    'Asturian (ast)': 'ast',
    'Azerbaijani (az)': 'az',
    'Bashkir (ba)': 'ba',
    'Belarusian (be)': 'be',
    'Bengali (bn)': 'bn',
    'Bosnian (bs)': 'bs',
    'Breton (br)': 'br',
    'Bulgarian (bg)': 'bg',
    'Burmese (my)': 'my',
    'Catalan; Valencian (ca)': 'ca',
    'Cebuano (ceb)': 'ceb',
    'Central Khmer (km)': 'km',
    'Chinese (zh)': 'zh',
    'Croatian (hr)': 'hr',
    'Czech (cs)': 'cs',
    'Danish (da)': 'da',
    'Dutch; Flemish (nl)': 'nl',
    'English (en)': 'en',
    'Estonian (et)': 'et',
    'Finnish (fi)': 'fi',
    'French (fr)': 'fr',
    'Fulah (ff)': 'ff',
    'Gaelic; Scottish Gaelic (gd)': 'gd',
    'Galician (gl)': 'gl',
    'Ganda (lg)': 'lg',
    'Georgian (ka)': 'ka',
    'German (de)': 'de',
    'Greek (el)': 'el',
    'Gujarati (gu)': 'gu',
    'Haitian; Haitian Creole (ht)': 'ht',
    'Hausa (ha)': 'ha',
    'Hebrew (he)': 'he',
    'Hindi (hi)': 'hi',
    'Hungarian (hu)': 'hu',
    'Icelandic (is)': 'is',
    'Igbo (ig)': 'ig',
    'Iloko (ilo)': 'ilo',
    'Indonesian (id)': 'id',
    'Irish (ga)': 'ga',
    'Italian (it)': 'it',
    'Japanese (ja)': 'ja',
    'Javanese (jv)': 'jv',
    'Kannada (kn)': 'kn',
    'Kazakh (kk)': 'kk',
    'Korean (ko)': 'ko',
    'Lao (lo)': 'lo',
    'Latvian (lv)': 'lv',
    'Lingala (ln)': 'ln',
    'Lithuanian (lt)': 'lt',
    'Luxembourgish; Letzeburgesch (lb)': 'lb',
    'Macedonian (mk)': 'mk',
    'Malagasy (mg)': 'mg',
    'Malay (ms)': 'ms',
    'Malayalam (ml)': 'ml',
    'Marathi (mr)': 'mr',
    'Mongolian (mn)': 'mn',
    'Nepali (ne)': 'ne',
    'Northern Sotho (ns)': 'ns',
    'Norwegian (no)': 'no',
    'Occitan (post 1500) (oc)': 'oc',
    'Oriya (or)': 'or',
    'Panjabi; Punjabi (pa)': 'pa',
    'Persian (fa)': 'fa',
    'Polish (pl)': 'pl',
    'Portuguese (pt)': 'pt',
    'Pushto; Pashto (ps)': 'ps',
    'Romanian; Moldavian; Moldovan (ro)': 'ro',
    'Russian (ru)': 'ru',
    'Serbian (sr)': 'sr',
    'Sindhi (sd)': 'sd',
    'Sinhala; Sinhalese (si)': 'si',
    'Slovak (sk)': 'sk',
    'Slovenian (sl)': 'sl',
    'Somali (so)': 'so',
    'Spanish (es)': 'es',
    'Sundanese (su)': 'su',
    'Swahili (sw)': 'sw',
    'Swati (ss)': 'ss',
    'Swedish (sv)': 'sv',
    'Tagalog (tl)': 'tl',
    'Tamil (ta)': 'ta',
    'Thai (th)': 'th',
    'Tswana (tn)': 'tn',
    'Turkish (tr)': 'tr',
    'Ukrainian (uk)': 'uk',
    'Urdu (ur)': 'ur',
    'Uzbek (uz)': 'uz',
    'Vietnamese (vi)': 'vi',
    'Welsh (cy)': 'cy',
    'Western Frisian (fy)': 'fy',
    'Wolof (wo)': 'wo',
    'Xhosa (xh)': 'xh',
    'Yiddish (yi)': 'yi',
    'Yoruba (yo)': 'yo',
    'Zulu (zu)': 'zu',
}

# Target languages offered for translation (display label -> language code).
tgt_language_codes = {
    'English (en)': 'en',
    'French (fr)': 'fr',
}


def build_text_chunks(
        text: str,
        sents_per_chunk: int = 5,
        words_per_chunk: int = 200) -> list[str]:
    """Split a text into chunks of at most sents_per_chunk sentences and
    (where possible) at most words_per_chunk words.

    Given a text:
    - Split the text into sentences (with the spaCy pipeline `model_spacy`),
      dropping blank sentences.
    - Greedily group consecutive sentences into chunks, starting a new chunk
      whenever adding the next sentence would exceed either limit.

    Sentences are never split: a single sentence longer than
    words_per_chunk ends up alone in its own chunk.

    Args:
        text: the text to split.
        sents_per_chunk: maximum number of sentences per chunk.
        words_per_chunk: maximum number of whitespace-separated words per
            chunk (exceeded only by a single over-long sentence).

    Returns:
        The chunks, in text order, with one sentence per line inside each
        chunk. Empty list if the text contains no non-blank sentence.
    """
    # Split text into sentences (blank ones are discarded)
    sentences = [
        sent.text.strip()
        for sent in model_spacy(text).sents
        if sent.text.strip()
    ]
    logger.info("TEXT: %s, NB_SENTS: %d", text[:25], len(sentences))

    chunks = []
    chunk_sents = []    # sentences of the chunk being built
    chunk_nb_words = 0  # word count of the chunk being built
    for sent in sentences:
        sent_nb_words = len(sent.split())

        # Close the current chunk if adding this sentence would overflow it.
        # Only non-empty chunks are closed: otherwise a first sentence longer
        # than words_per_chunk would produce an empty '' chunk.
        if chunk_sents and (
                chunk_nb_words + sent_nb_words > words_per_chunk
                or len(chunk_sents) + 1 > sents_per_chunk):
            chunks.append('\n'.join(chunk_sents))
            chunk_sents = []
            chunk_nb_words = 0

        chunk_sents.append(sent)
        chunk_nb_words += sent_nb_words

    # Append last chunk
    if chunk_sents:
        chunks.append('\n'.join(chunk_sents))
    return chunks


class Singleton(type):
    """Metaclass giving each class that uses it a single shared instance.

    The first call to the class constructs the instance; every later call
    returns that same instance, ignoring the arguments passed.
    NOTE(review): creation is not guarded by a lock; confirm classes are
    never first instantiated from several threads at once.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class ModelM2M100(metaclass=Singleton):
    """Loads an instance of the M2M100 model (418M parameters).

    Model: https://huggingface.co/facebook/m2m100_418M
    """
    def __init__(self):
        self._model_name = "facebook/m2m100_418M"
        self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
        # Loaded in float16 without 8-bit quantization (quantization_config
        # is deliberately not passed for this model).
        self._model = M2M100ForConditionalGeneration.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        self._model = torch.compile(self._model)

    @spaces.GPU
    def translate(
            self,
            text: str,
            src_lang: str,
            tgt_lang: str,
            chunk_text: bool = True,
            sents_per_chunk: int = 5,
            words_per_chunk: int = 200) -> str:
        """Translate the given text from src_lang to tgt_lang.

        The text will be split into chunks to ensure the chunks fit into the
        model input_max_length (usually 512 tokens). Each chunk is translated
        independently.

        Args:
            text: text to translate.
            src_lang: M2M100 source language code (e.g. 'fr').
            tgt_lang: M2M100 target language code (e.g. 'en').
            chunk_text: if False, translate the whole text in one pass.
            sents_per_chunk: max sentences per chunk (see build_text_chunks).
            words_per_chunk: max words per chunk (see build_text_chunks).

        Returns:
            The translated chunks joined with newlines.
        """
        chunks = (
            build_text_chunks(text, sents_per_chunk, words_per_chunk)
            if chunk_text else [text]
        )

        # M2M100 reads the source language from the tokenizer and is told
        # the target language by forcing the first generated token.
        self._tokenizer.src_lang = src_lang
        forced_bos_token_id = self._tokenizer.get_lang_id(tgt_lang)

        translated_chunks = []
        for chunk in chunks:
            # truncation=True caps the input at the tokenizer's
            # model_max_length instead of failing on an over-long chunk
            # (consistent with ModelMADLAD).
            input_ids = self._tokenizer(
                chunk, return_tensors="pt", truncation=True
            ).input_ids.to(self._model.device)
            # NOTE(review): output length is governed by the model's default
            # generation config; confirm it is large enough for
            # words_per_chunk-sized chunks.
            outputs = self._model.generate(
                input_ids=input_ids,
                forced_bos_token_id=forced_bos_token_id)
            translated_chunks.append(self._tokenizer.batch_decode(
                outputs, skip_special_tokens=True)[0])
        return '\n'.join(translated_chunks)

    @property
    def model_name(self):
        return self._model_name

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        return self._model

    @property
    def device(self):
        return self._model.device


class ModelMADLAD(metaclass=Singleton):
    """Loads an instance of the Google MADLAD model (3B).

    Model: https://huggingface.co/google/madlad400-3b-mt
    """
    def __init__(self):
        self._model_name = "google/madlad400-3b-mt"
        self._input_max_length = 512   # config.json n_positions
        self._output_max_length = 512  # config.json n_positions
        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, use_fast=True
        )
        self._model = AutoModelForSeq2SeqLM.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=quantization_config
        )
        self._model = torch.compile(self._model)

    @spaces.GPU
    def translate(
            self,
            text: str,
            tgt_lang: str,
            chunk_text: bool = True,
            sents_per_chunk: int = 5,
            words_per_chunk: int = 5) -> str:
        """Translate given text into the target language.

        The text will be split into chunks to ensure the chunks fit into the
        model input_max_length (usually 512 tokens). Each chunk is translated
        independently; the source language is not needed.

        Args:
            text: text to translate.
            tgt_lang: target language code, inserted as the `<2xx>` prefix.
            chunk_text: if False, translate the whole text in one pass.
            sents_per_chunk: max sentences per chunk (see build_text_chunks).
            words_per_chunk: max words per chunk (see build_text_chunks).
                NOTE(review): the default of 5 (vs 200 for M2M100) makes
                nearly every sentence its own chunk; looks like a typo for
                200 but is kept to preserve current behavior.

        Returns:
            The translated chunks joined with newlines.
        """
        chunks = (
            build_text_chunks(text, sents_per_chunk, words_per_chunk)
            if chunk_text else [text]
        )

        translated_chunks = []
        for chunk in chunks:
            # MADLAD selects the target language via a "<2xx>" prefix token.
            input_text = f"<2{tgt_lang}> {chunk}"
            logger.info(" Translating: %s", input_text[:50])
            input_ids = self._tokenizer(
                input_text,
                return_tensors="pt",
                max_length=self._input_max_length,
                truncation=True,
                padding="longest").input_ids.to(self._model.device)
            outputs = self._model.generate(
                input_ids=input_ids,
                max_length=self._output_max_length)
            translated_chunks.append(self._tokenizer.decode(
                outputs[0], skip_special_tokens=True))
        return '\n'.join(translated_chunks)

    @property
    def model_name(self):
        return self._model_name

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        return self._model

    @property
    def device(self):
        return self._model.device


# Bi-lingual individual models
# NOTE(review): "ja" is listed here but has no entry in model_names, so
# get_tokenizer_model_for_src_lang("ja") raises; confirm intended.
src_langs = {"ar", "en", "fa", "fr", "he", "ja", "zh"}
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Registry for all loaded bilingual models: src_lang -> (tokenizer, model)
tokenizer_model_registry = {}
device = 'cpu'


def get_tokenizer_model_for_src_lang(
        src_lang: str) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) pair for a given source language.

    Models are loaded lazily on first request, cast to float16 if needed,
    moved to the module-level `device`, and cached in
    tokenizer_model_registry so later calls reuse them.

    Args:
        src_lang: source language code (case-insensitive); must be a key of
            model_names.

    Returns:
        The (tokenizer, model) tuple for that language.

    Raises:
        ValueError: if no bilingual model is defined for src_lang.
    """
    src_lang = src_lang.lower()

    # Already loaded?
    cached = tokenizer_model_registry.get(src_lang)
    if cached is not None:
        return cached

    # Load tokenizer and model
    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    # We will leave the models on the CPU (for now)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # NOTE(review): float16 inference on CPU is unsupported for some ops in
    # older torch versions; confirm the deployed torch handles it.
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return (tokenizer, model)


# Max number of words for given input text
# - Usually 512 tokens (max position encodings, as well as max length)
# - Let's set to some number of words somewhat lower than that threshold
# - e.g. 200 words
max_words_per_chunk = 200