# Author: SudiptoPramanik — commit 6899674 ("Update app.py", verified).
import gradio as gr
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using a token taken from the environment.
# An unset or empty variable aborts start-up immediately with a descriptive error.
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
if not hf_token:
    raise ValueError(
        "HUGGINGFACE_HUB_TOKEN is not set. Please set it as an environment variable or provide it programmatically."
    )
login(token=hf_token)
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import transformers
import time
from transformers import TrainingArguments, Trainer
import torch
from peft import PeftModel, PeftConfig
# Hub identifiers for the base checkpoint and the PEFT adapter loaded on top of it.
_BASE_MODEL_ID = "mistralai/Mistral-7B-v0.1"
_ADAPTER_ID = "SudiptoPramanik/Mistral_FineTunedModel_for_Non-spam_Mail_Generation"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base causal LM (bfloat16 weights) and its tokenizer; the tokenizer has no
# dedicated pad token here, so the EOS token is reused for padding.
peft_model_base = AutoModelForCausalLM.from_pretrained(
    _BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
)
Gen_tokenizer = AutoTokenizer.from_pretrained(_BASE_MODEL_ID)
Gen_tokenizer.pad_token = Gen_tokenizer.eos_token

# Wrap the base model with the fine-tuned adapter, loaded in non-trainable mode.
Gen_model = PeftModel.from_pretrained(
    peft_model_base,
    _ADAPTER_ID,
    torch_dtype=torch.bfloat16,
    is_trainable=False,
)
Gen_model.to(device)
def Gen_make_inference(Question):
    """Generate a non-spam mail for ``Question`` with the fine-tuned model.

    The question is embedded in the fixed instruction template, tokenized
    (truncated to 220 tokens), and passed to ``Gen_model.generate`` with
    top-k / nucleus sampling. Only the tokens generated after the prompt are
    decoded; anything from the first ``*`` onward is dropped.

    Args:
        Question: Free-text prompt entered in the Gradio textbox.

    Returns:
        The generated mail text (may be empty if the model produced nothing
        before an end-of-sequence token or a ``*``).
    """
    Prompt = "Read the following Question carefully and then generate an appropriate answer to this question correctly."
    Grade = "not spam"
    input_template = f"prompt:{Prompt}\n\n ### Input(Grade):{Grade}\n ### Input(Question):{Question}\n\n### Answer:"

    # NOTE(review): truncation removes tokens from the end of the template, so a
    # very long Question also loses the "### Answer:" cue — consider truncating
    # the Question itself instead.
    input_tokens = Gen_tokenizer(
        input_template,
        return_tensors='pt',
        max_length=220,
        truncation=True,
        padding=True,
    ).to(device)

    # The model is already loaded in bfloat16, so no autocast is needed. The old
    # torch.cuda.amp.autocast() is deprecated, autocasts to float16 by default,
    # and is disabled with a warning on CPU-only hosts.
    with torch.inference_mode():
        output_tokens = Gen_model.generate(
            **input_tokens,
            # max_length counts prompt + generated tokens, so long prompts leave
            # fewer tokens for the answer.
            max_length=250,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # Without do_sample=True generation is greedy and top_k/top_p are ignored.
            do_sample=True,
            top_k=50,  # sample only from the 50 most likely tokens
            top_p=0.95,  # nucleus sampling
            eos_token_id=Gen_tokenizer.eos_token_id,
            pad_token_id=Gen_tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; decode only the continuation so
    # text from the user's Question (e.g. one containing "Answer:") can never be
    # mistaken for the model's answer.
    prompt_length = input_tokens["input_ids"].shape[1]
    generated = Gen_tokenizer.decode(
        output_tokens[0][prompt_length:], skip_special_tokens=True
    )

    # Treat '*' as an end-of-answer marker; partition() is a no-op when it is absent.
    text, _, _ = generated.strip().partition('*')
    return text.strip()
# Gradio UI: a prompt textbox, a generate button, and an output textbox, created
# in that order inside a single Blocks container bound to `demo`.
with gr.Blocks() as demo:
    question = gr.Textbox(label="Mail Prompt")
    gen_btn = gr.Button("Non Spam Mail Generation")
    answer_gen = gr.Textbox(label="Desired Mail Response")
    # On click, the textbox value is passed to Gen_make_inference and its return
    # value is shown in answer_gen; the event is registered with api_name "Generator".
    gen_btn.click(fn=Gen_make_inference, inputs=question, outputs=[answer_gen], api_name="Generator")
def _launch_demo():
    """Start the Gradio app defined by ``demo`` (share=True, inline=False)."""
    demo.launch(inline=False, share=True)


if __name__ == "__main__":
    _launch_demo()