"""Gradio demo that generates "not spam" mail text.

The script logs in to the Hugging Face Hub, loads the Mistral-7B base model
together with a PEFT adapter, and exposes a single text-in/text-out endpoint
through a Gradio Blocks UI.
"""

import os
import time

import gradio as gr
import torch
import transformers
from datasets import load_dataset
from huggingface_hub import login
from peft import PeftConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

BASE_MODEL_ID = "mistralai/Mistral-7B-v0.1"
ADAPTER_ID = "SudiptoPramanik/Mistral_FineTunedModel_for_Non-spam_Mail_Generation"

# Longest prompt (in tokens) that is sent to the model.
MAX_INPUT_TOKENS = 220
# Generation budget, counted separately from the prompt.
MAX_NEW_TOKENS = 250

PARSE_ERROR_MESSAGE = "Error: Unable to parse the model's response."
PROMPT_TOO_LONG_MESSAGE = (
    "Error: The mail prompt is too long. Please shorten it and try again."
)

# --- Authentication ---------------------------------------------------------
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not hf_token:
    raise ValueError(
        "HUGGINGFACE_HUB_TOKEN is not set. Please set it as an environment "
        "variable or provide it programmatically."
    )
login(hf_token)

# --- Model and tokenizer ----------------------------------------------------
peft_model_base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
)
Gen_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# The tokenizer is given the EOS token as its padding token.
Gen_tokenizer.pad_token = Gen_tokenizer.eos_token

Gen_model = PeftModel.from_pretrained(
    peft_model_base,
    ADAPTER_ID,
    torch_dtype=torch.bfloat16,
    is_trainable=False,
)

# Move the model to GPU if available.
device = "cuda" if torch.cuda.is_available() else "cpu"
Gen_model.to(device)


def _extract_answer(completion: str) -> str:
    """Return the answer part of *completion*, cut at the first stop marker.

    The text is cut at a repeated ``Answer:`` header and at the first ``*``,
    then stripped of surrounding whitespace.
    """
    for marker in ("Answer:", "*"):
        completion = completion.split(marker)[0]
    return completion.strip()


def Gen_make_inference(Question: str) -> str:
    """Generate a "not spam" mail answer for *Question*.

    Args:
        Question: Free-text mail prompt entered by the user.

    Returns:
        The generated answer text, or a string starting with ``"Error:"``
        when the prompt is too long or no usable answer was produced.
    """
    Prompt = (
        "Read the following Question carefully and then generate an "
        "appropriate answer to this question correctly."
    )
    Grade = "not spam"
    input_template = (
        f"prompt:{Prompt}\n\n"
        f" ### Input(Grade):{Grade}\n"
        f" ### Input(Question):{Question}\n\n"
        "### Answer:"
    )

    input_tokens = Gen_tokenizer(input_template, return_tensors="pt")
    prompt_length = input_tokens["input_ids"].shape[1]
    # Truncating the prompt would cut off the trailing "### Answer:" cue, so
    # an over-long prompt is rejected up front instead of being generated
    # from and then failing to parse.
    if prompt_length > MAX_INPUT_TOKENS:
        return PROMPT_TOO_LONG_MESSAGE
    input_tokens = input_tokens.to(device)

    # Autocast for the device the model actually lives on, in the same dtype
    # the weights were loaded with.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        output_tokens = Gen_model.generate(
            **input_tokens,
            # Budget for generated tokens only; a total-length limit would
            # shrink the answer as the prompt grows.
            max_new_tokens=MAX_NEW_TOKENS,
            # top_k / top_p only take effect when sampling is enabled.
            do_sample=True,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            top_k=50,
            top_p=0.95,  # Nucleus sampling
            eos_token_id=Gen_tokenizer.eos_token_id,
            pad_token_id=Gen_tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens so that text typed by the user
    # (which may itself contain "Answer:") is never mistaken for the answer.
    completion = Gen_tokenizer.decode(
        output_tokens[0][prompt_length:],
        skip_special_tokens=True,
    )
    return _extract_answer(completion) or PARSE_ERROR_MESSAGE


# --- User interface ---------------------------------------------------------
with gr.Blocks() as demo:
    question = gr.Textbox(label="Mail Prompt")
    gen_btn = gr.Button("Non Spam Mail Generation")
    answer_gen = gr.Textbox(label="Desired Mail Response")
    gen_btn.click(
        fn=Gen_make_inference,
        inputs=question,
        outputs=[answer_gen],
        api_name="Generator",
    )

if __name__ == "__main__":
    demo.launch(share=True, inline=False)