import os

import gradio as gr
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")

if not hf_token:
    raise ValueError(
        "HUGGINGFACE_HUB_TOKEN is not set. Please set it as an environment variable or provide it programmatically."
    )

login(hf_token)
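
# Load the Mistral-7B base model in bfloat16 and attach the fine-tuned
# PEFT (LoRA) adapter; is_trainable=False keeps the adapter frozen for inference.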
peft_model_base = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.bfloat16
)

Gen_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
Gen_tokenizer.pad_token = Gen_tokenizer.eos_token

Gen_model = PeftModel.from_pretrained(
    peft_model_base,
    "SudiptoPramanik/Mistral_FineTunedModel_for_Non-spam_Mail_Generation",
    torch_dtype=torch.bfloat16,
    is_trainable=False,
)
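
# Run on GPU when one is available; CPU works but is slow for a 7B model.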
device = "cuda" if torch.cuda.is_available() else "cpu"
Gen_model.to(device)
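

# Build the training-style prompt, tokenize it, generate, and strip the
# echoed prompt from the decoded output.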
def Gen_make_inference(Question):
    Prompt = "Read the following Question carefully and then generate an appropriate answer to this question correctly."
    Grade = "not spam"
    input_template = f"prompt:{Prompt}\n\n ### Input(Grade):{Grade}\n ### Input(Question):{Question}\n\n### Answer:"
    input_tokens = Gen_tokenizer(
        input_template,
        return_tensors='pt',
        max_length=220,
        truncation=True,
        padding=True
    ).to(device)
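
    # top_k/top_p only influence decoding when sampling is enabled, so
    # do_sample=True is passed below; torch.autocast handles CUDA and CPU alike.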
    with torch.no_grad(), torch.autocast(device_type=device):
        output_tokens = Gen_model.generate(
            **input_tokens,
            max_length=250,  # prompt + completion combined, so ~30 new tokens after a full-length prompt
            num_return_sequences=1,
            do_sample=True,
            no_repeat_ngram_size=2,
            top_k=50,
            top_p=0.95,
            eos_token_id=Gen_tokenizer.eos_token_id,
            pad_token_id=Gen_tokenizer.eos_token_id
        )

    Gen_answer = Gen_tokenizer.decode(output_tokens[0], skip_special_tokens=True)
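
    # The decoded text echoes the whole prompt, so keep only what follows the
    # "Answer:" marker; anything from a trailing '*' onward is trimmed away.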
    try:
        text = Gen_answer.split("Answer:")[1].strip()
        if '*' in text:
            text = text.split('*')[0].strip()
        return text
    except IndexError:
        return "Error: Unable to parse the model's response."
with gr.Blocks() as demo:
    question = gr.Textbox(label="Mail Prompt")
    gen_btn = gr.Button("Non-Spam Mail Generation")
    answer_gen = gr.Textbox(label="Desired Mail Response")
    gen_btn.click(fn=Gen_make_inference, inputs=question, outputs=[answer_gen], api_name="Generator")

if __name__ == "__main__":
    demo.launch(share=True, inline=False)
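
# Example client call once the app is running: a hedged sketch that assumes
# the `gradio_client` package and the share URL printed at launch; the URL
# below is a placeholder, not a real endpoint.
#
#   from gradio_client import Client
#
#   client = Client("https://<share-id>.gradio.live")  # hypothetical share URL
#   print(client.predict("Write a polite follow-up mail to a customer.",
#                        api_name="/Generator"))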