{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5b0c7e1a-3f2d-4a8e-9c61-2e4d7f9a1b30",
   "metadata": {},
   "source": [
    "# ASPMIR Machine Translation Testbed (English → Yoruba)\n",
    "\n",
    "Translates English text with a selectable Hugging Face model, synthesizes the Yoruba output with `facebook/mms-tts-yor`, and scores it against a fixed Yoruba reference with BLEU, WER, METEOR and an MFCC-based Mel Cepstral Distance (MCD).\n",
    "MCD additionally needs a reference recording `gt_ttsOutput.wav` in the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23e98a8a-7128-4f35-ba1c-ff514ed462e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install all required dependencies (uncomment and run once per environment).\n",
    "# %pip installs into the environment of the running kernel, unlike !pip.\n",
    "# %pip install torch torchvision torchaudio\n",
    "# %pip install --upgrade transformers accelerate sentencepiece sacremoses\n",
    "# %pip install --upgrade gradio ipywidgets\n",
    "# %pip install nltk jiwer soundfile librosa numpy huggingface_hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d2a7d3a-8c2c-4134-a79f-a3b7b1747874",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required libraries (standard library first, then third-party).\n",
    "import os\n",
    "import string\n",
    "import unicodedata\n",
    "from functools import lru_cache\n",
    "\n",
    "import gradio as gr\n",
    "import librosa\n",
    "import nltk\n",
    "import numpy as np\n",
    "import soundfile as sf\n",
    "import torch\n",
    "from jiwer import wer\n",
    "from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu\n",
    "from nltk.translate.meteor_score import meteor_score\n",
    "from transformers import AutoTokenizer, VitsModel, pipeline, set_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2bafb31-ecf6-44e4-b25a-24abfa75bed1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the WordNet data used by METEOR (a no-op when already up to date).\n",
    "nltk.download('wordnet')\n",
    "nltk.download('omw-1.4')  # Optional if using WordNet's multilingual features\n",
    "# Also search a project-local nltk_data directory.\n",
    "nltk.data.path.append('./nltk_data')\n",
    "print(nltk.data.path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10ceb8b4-fe4e-4a97-ac34-dce6a890455a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define all utility functions used to score a translation.\n",
    "\n",
    "# Translation table that strips ASCII punctuation; built once and reused.\n",
    "_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)\n",
    "\n",
    "\n",
    "def _nfc(text):\n",
    "    \"\"\"\n",
    "    Returns `text` in Unicode NFC form.\n",
    "    Yoruba tone marks and under-dots can be encoded either as precomposed characters\n",
    "    or as combining sequences; without normalization, identical-looking words would\n",
    "    be counted as different tokens by every metric below.\n",
    "    \"\"\"\n",
    "    return unicodedata.normalize('NFC', text)\n",
    "\n",
    "\n",
    "def _normalize_text(text):\n",
    "    \"\"\"Lowercases, NFC-normalizes and strips ASCII punctuation (shared by WER and METEOR).\"\"\"\n",
    "    return _nfc(text.lower()).translate(_PUNCTUATION_TABLE)\n",
    "\n",
    "\n",
    "# Function to compute BLEU score\n",
    "def compute_bleu(reference_text, predicted_text):\n",
    "    \"\"\"\n",
    "    Computes the sentence-level BLEU score for a single translation.\n",
    "    Case and punctuation are kept as-is; only the Unicode form is unified.\n",
    "    :param reference_text: The ground truth text (in Yoruba).\n",
    "    :param predicted_text: The machine-generated translation text (in Yoruba).\n",
    "    :return: BLEU score rounded to 2 decimals (float).\n",
    "    \"\"\"\n",
    "    print(\"The Reference Text = \", reference_text)\n",
    "    print(\"The Predicted Text = \", predicted_text)\n",
    "    # sentence_bleu takes a list of tokenized references and one tokenized hypothesis.\n",
    "    reference_tokens = [_nfc(reference_text).split()]\n",
    "    predicted_tokens = _nfc(predicted_text).split()\n",
    "    # Smoothing avoids a zero score when some higher-order n-grams have no match.\n",
    "    smoothing_function = SmoothingFunction().method1\n",
    "    bleu_score = sentence_bleu(reference_tokens, predicted_tokens, smoothing_function=smoothing_function)\n",
    "    return round(bleu_score, 2)\n",
    "\n",
    "\n",
    "# Function to compute Word Error Rate (WER)\n",
    "def compute_wer(reference_text, predicted_text):\n",
    "    \"\"\"\n",
    "    Computes the Word Error Rate (WER) for a single translation (lower is better).\n",
    "    :param reference_text: The ground truth text (in Yoruba).\n",
    "    :param predicted_text: The machine-generated translation text (in Yoruba).\n",
    "    :return: WER rounded to 2 decimals (float).\n",
    "    \"\"\"\n",
    "    wer_score = wer(_normalize_text(reference_text), _normalize_text(predicted_text))\n",
    "    return round(wer_score, 2)\n",
    "\n",
    "\n",
    "# Function to compute METEOR score\n",
    "def compute_meteor(reference_text, predicted_text):\n",
    "    \"\"\"\n",
    "    Computes the METEOR score for a single translation.\n",
    "    NLTK's meteor_score expects pre-tokenized input, so both texts are normalized and split.\n",
    "    :param reference_text: The ground truth text (in Yoruba).\n",
    "    :param predicted_text: The machine-generated translation text (in Yoruba).\n",
    "    :return: METEOR score rounded to 2 decimals (float).\n",
    "    \"\"\"\n",
    "    reference_tokens = _normalize_text(reference_text).split()\n",
    "    predicted_tokens = _normalize_text(predicted_text).split()\n",
    "    return round(meteor_score([reference_tokens], predicted_tokens), 2)\n",
    "\n",
    "\n",
    "# Function to compute Mel Cepstral Distance (MCD)\n",
    "def compute_mcd(ground_truth_audio_path, predicted_audio_path, sr=16000, n_mfcc=13):\n",
    "    \"\"\"\n",
    "    Computes an MFCC-based Mel Cepstral Distance (MCD) between two audio files.\n",
    "    Both files are resampled to `sr`; their MFCC frames are compared pairwise over the\n",
    "    shorter length (no time alignment such as DTW), and the mean per-frame Euclidean\n",
    "    distance is scaled by 10/ln(10).\n",
    "    :param ground_truth_audio_path: Path to the ground truth audio file.\n",
    "    :param predicted_audio_path: Path to the predicted audio file.\n",
    "    :param sr: Sampling rate both files are resampled to (default 16000).\n",
    "    :param n_mfcc: Number of MFCCs per frame, c0 included (default 13).\n",
    "    :return: MCD score rounded to 2 decimals (float).\n",
    "    :raises ValueError: If either file yields no MFCC frames.\n",
    "    \"\"\"\n",
    "    # Loading both files at the same `sr` guarantees matching sampling rates.\n",
    "    y_true, _ = librosa.load(ground_truth_audio_path, sr=sr)\n",
    "    y_pred, _ = librosa.load(predicted_audio_path, sr=sr)\n",
    "\n",
    "    # librosa returns (n_mfcc, frames); transpose to one row per frame.\n",
    "    mfcc_true = librosa.feature.mfcc(y=y_true, sr=sr, n_mfcc=n_mfcc).T\n",
    "    mfcc_pred = librosa.feature.mfcc(y=y_pred, sr=sr, n_mfcc=n_mfcc).T\n",
    "\n",
    "    min_frames = min(len(mfcc_true), len(mfcc_pred))\n",
    "    if min_frames == 0:\n",
    "        raise ValueError(\"Cannot compute MCD: an audio file produced no MFCC frames.\")\n",
    "\n",
    "    frame_distances = np.linalg.norm(mfcc_true[:min_frames] - mfcc_pred[:min_frames], axis=1)\n",
    "    mcd = (10.0 / np.log(10)) * frame_distances.mean()\n",
    "    return round(float(mcd), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69d64db9-b083-46ae-80ce-9616ba99183d",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Define translation, speech synthesis and scoring for one request.\n",
    "\n",
    "TTS_MODEL_ID = \"facebook/mms-tts-yor\"\n",
    "TTS_OUTPUT_PATH = \"ttsOutput.wav\"  # overwritten on every request (relative to the CWD)\n",
    "GT_AUDIO_PATH = \"gt_ttsOutput.wav\"  # reference speech for MCD (relative to the CWD)\n",
    "TTS_SEED = 555\n",
    "\n",
    "# Sample ground-truth Yoruba translations for each test domain.\n",
    "GROUND_TRUTH_TEXTS = {\n",
    "    \"clinical\": \"Àrùn jẹjẹrẹ ọmú jẹ́ ọ̀kan pàtàkì lára ohun tó ń ṣàkóbá fún ìlera gbogbo ènìyàn ní Nàìjíríà, ó sì jẹ́ ọ̀kan pàtàkì lára ohun tó ń fa ikú àwọn obìnrin tí àrùn jẹjẹrẹ ń pa lórílẹ̀-èdè náà.\",\n",
    "    \"news\": \"Wọ́n ní ìgbà àkọ́kọ́ nìyí tí irú ìwà ipá bẹ́ẹ̀ máa wáyé ní ìpínlẹ̀ Ondo.\",\n",
    "    \"religion\": \"Àwọn aposteli, àwọn wòlíì, àwọn ajíhìnrere, àwọn olùṣọ́-àgùntàn àti àwọn olùkọ́.\",\n",
    "}\n",
    "# Reference used when the caller does not pass one.\n",
    "GROUND_TRUTH_TEXT = GROUND_TRUTH_TEXTS[\"religion\"]\n",
    "\n",
    "\n",
    "@lru_cache(maxsize=1)\n",
    "def _load_translator(modelName):\n",
    "    \"\"\"Builds and caches the en->yo translation pipeline; only the last model is kept to bound memory.\"\"\"\n",
    "    return pipeline('translation_en_to_yo', model=modelName, max_length=500)\n",
    "\n",
    "\n",
    "@lru_cache(maxsize=None)\n",
    "def _load_tts():\n",
    "    \"\"\"Loads and caches the Yoruba MMS VITS model together with its tokenizer.\"\"\"\n",
    "    return VitsModel.from_pretrained(TTS_MODEL_ID), AutoTokenizer.from_pretrained(TTS_MODEL_ID)\n",
    "\n",
    "\n",
    "def _synthesize(text, out_path=TTS_OUTPUT_PATH):\n",
    "    \"\"\"Synthesizes `text` with the Yoruba TTS model, writes it to `out_path` and returns that path.\"\"\"\n",
    "    tts_model, tokenizer = _load_tts()\n",
    "    tts_inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "    set_seed(TTS_SEED)  # VITS generation is stochastic; a fixed seed makes the audio reproducible\n",
    "    with torch.no_grad():\n",
    "        waveform = tts_model(**tts_inputs).waveform\n",
    "    # Write at the model's native rate rather than a hard-coded value.\n",
    "    sf.write(out_path, waveform.numpy()[0], tts_model.config.sampling_rate)\n",
    "    return out_path\n",
    "\n",
    "\n",
    "def translate_transformers(modelName, sourceLangText, groundTruthText=None, groundTruthAudio=None):\n",
    "    \"\"\"\n",
    "    Translates English text to Yoruba, synthesizes the translation and scores it.\n",
    "    :param modelName: Hugging Face id of an English->Yoruba translation model.\n",
    "    :param sourceLangText: English input text.\n",
    "    :param groundTruthText: Reference Yoruba text; defaults to GROUND_TRUTH_TEXT.\n",
    "    :param groundTruthAudio: Reference audio for MCD; defaults to GT_AUDIO_PATH in the CWD.\n",
    "    :return: (translated text, BLEU, WER, METEOR, MCD or None when it cannot be computed,\n",
    "              path of the synthesized audio file).\n",
    "    \"\"\"\n",
    "    if groundTruthText is None:\n",
    "        groundTruthText = GROUND_TRUTH_TEXT\n",
    "    if groundTruthAudio is None:\n",
    "        groundTruthAudio = os.path.join(os.getcwd(), GT_AUDIO_PATH)\n",
    "\n",
    "    translated_text = _load_translator(modelName)(sourceLangText)\n",
    "    translated_text_target = translated_text[0]['translation_text']\n",
    "\n",
    "    # TTS for the translated text\n",
    "    audio_path = _synthesize(translated_text_target)\n",
    "\n",
    "    bleu_score = compute_bleu(groundTruthText, translated_text_target)\n",
    "    print(f\"Bleu Score (BLEU): {bleu_score:.2f}\")\n",
    "\n",
    "    wer_score = compute_wer(groundTruthText, translated_text_target)\n",
    "    print(f\"Word Error Rate (WER): {wer_score:.2f}\")\n",
    "\n",
    "    meteor = compute_meteor(groundTruthText, translated_text_target)\n",
    "    print(f\"METEOR Score: {meteor:.2f}\")\n",
    "\n",
    "    # MCD is best-effort: a missing/unreadable reference file must not hide the text scores.\n",
    "    # Initialised up front so the return below never hits an unbound name.\n",
    "    mcd = None\n",
    "    try:\n",
    "        mcd = compute_mcd(groundTruthAudio, os.path.join(os.getcwd(), audio_path))\n",
    "        print(f\"Mel Cepstral Distance (MCD): {mcd:.2f}\")\n",
    "    except Exception as e:\n",
    "        print(f\"Error computing MCD: {e}\")\n",
    "\n",
    "    return translated_text_target, bleu_score, wer_score, meteor, mcd, audio_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbf259d6-922d-4f5c-9af1-cbd57158a814",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Define the user interface with Gradio.\n",
    "# English->Yoruba models offered in the dropdown (comment = experiment id).\n",
    "TRANSLATION_MODELS = [\n",
    "    \"Davlan/byt5-base-eng-yor-mt\",  # Exp1\n",
    "    \"Davlan/m2m100_418M-eng-yor-mt\",  # Exp2\n",
    "    \"Davlan/mbart50-large-eng-yor-mt\",  # Exp3\n",
    "    \"Davlan/mt5_base_eng_yor_mt\",  # Exp4\n",
    "    \"omoekan/opus-tatoeba-eng-yor\",  # Exp5\n",
    "    \"masakhane/afrimt5_en_yor_news\",  # Exp6\n",
    "    \"masakhane/afrimbart_en_yor_news\",  # Exp7\n",
    "    \"masakhane/afribyt5_en_yor_news\",  # Exp8\n",
    "    \"masakhane/byt5_en_yor_news\",  # Exp9\n",
    "    \"masakhane/mt5_en_yor_news\",  # Exp10\n",
    "    \"masakhane/mbart50_en_yor_news\",  # Exp11\n",
    "    \"masakhane/m2m100_418M_en_yor_news\",  # Exp12\n",
    "    \"masakhane/m2m100_418M_en_yor_rel_news\",  # Exp13\n",
    "    \"masakhane/m2m100_418M_en_yor_rel_news_ft\",  # Exp14\n",
    "    \"masakhane/m2m100_418M_en_yor_rel\",  # Exp15\n",
    "    \"dabagyan/menyo_en2yo\",  # Exp16\n",
    "    \"facebook/m2m100_418M\",  # Exp17\n",
    "    \"google/mt5-base\",  # Exp20\n",
    "    \"google/byt5-large\",  # Exp21\n",
    "]\n",
    "# Currently disabled: facebook/nllb-200-distilled-600M, facebook/nllb-200-3.3B,\n",
    "# facebook/nllb-200-1.3B, facebook/nllb-200-distilled-1.3B, keithhon/nllb-200-3.3B,\n",
    "# CohereForAI/aya-101, facebook/m2m100_1.2B, facebook/m2m100-12B-avg-5-ckpt.\n",
    "\n",
    "interface = gr.Interface(\n",
    "    fn=translate_transformers,\n",
    "    inputs=[\n",
    "        gr.Dropdown(TRANSLATION_MODELS, value=TRANSLATION_MODELS[0],\n",
    "                    label=\"Select Finetuned Eng2Yor Translation Model\"),\n",
    "        gr.Textbox(lines=2, placeholder=\"Enter English Text Here...\", label=\"English Text\"),\n",
    "    ],\n",
    "    # One output component per value returned by translate_transformers, in order.\n",
    "    outputs=[\n",
    "        gr.Textbox(label=\"Translated Yoruba Text\"),\n",
    "        gr.Textbox(label=\"BLEU SCORE\"),\n",
    "        gr.Textbox(label=\"WER(WORD ERROR RATE) SCORE - The Lower the Better\"),\n",
    "        gr.Textbox(label=\"METEOR SCORE\"),\n",
    "        gr.Textbox(label=\"MCD(MEL CEPSTRAL DISTANCE) SCORE\"),\n",
    "        gr.Audio(type=\"filepath\", label=\"Click to Generate Yoruba Speech from the Translated Text\"),\n",
    "    ],\n",
    "    title=\"ASPMIR-MACHINE-TRANSLATION-TESTBED FOR LOW RESOURCED AFRICAN LANGUAGES\",\n",
    "    description=\"This Tool Allows Developers and Researchers to Carry Out Experiments on Low Resourced African Languages with State-of-the-Art Pretrained or Finetuned Models.\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3baee0f-fd85-4209-9d54-14451abd372a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# share=True also publishes a temporary public gradio.live URL.\n",
    "if __name__ == \"__main__\":\n",
    "    interface.launch(share=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}