import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch

# Model initialization
repo_name = "BeardedMonster/SabiYarn-125M-translate"  # Model repository
tokenizer_name = "BeardedMonster/SabiYarn-125M"  # Tokenizer repository

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Define generation configuration.
# Sampling parameters (temperature, top_k, top_p) are omitted because
# do_sample=False makes them unused, and transformers warns if they are set.
generation_config = GenerationConfig(
    max_new_tokens=50,       # Upper bound on generated tokens; raise for longer translations
    num_beams=5,             # Moderate beam count to balance speed and quality
    do_sample=False,         # Beam search without sampling: deterministic output
    repetition_penalty=4.0,  # Strong penalty against repeated phrases (1.0 is neutral)
    length_penalty=3.0,      # >1.0 favors longer sequences under beam search (1.0 is neutral)
    early_stopping=True,     # Stop once all beams are finished to speed up generation
)

def generate_text(prompt, language):
    # Add the language tag to the prompt
    tagged_prompt = f" <{language.lower()}> {prompt} "

    # Tokenize
    inputs = tokenizer(tagged_prompt, return_tensors="pt", padding=True, truncation=True).to(device)
    print(f"Tagged Prompt: {tagged_prompt}")
    print(f"Inputs: {inputs}")
    print(f"Input IDs shape: {inputs['input_ids'].shape}")
    print(f"Attention Mask shape: {inputs['attention_mask'].shape}")

    # Generate
    try:
        outputs = model.generate(**inputs, generation_config=generation_config)
        # Decode and return
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during generation: {e}")
        return "An error occurred during text generation."

# Create Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt"),
        gr.Dropdown(choices=["yor", "ibo", "hau", "efi", "pcm", "urh"], label="Select Language"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Nigerian Language Generator",
    description="Generate text in Yoruba, Igbo, Hausa, Efik, Pidgin, or Urhobo using the Sabi Yarn model.",
)

if __name__ == "__main__":
    iface.launch(share=True)  # share=True creates a public link