SudiptoPramanik committed
Commit de33b14 · verified · 1 Parent(s): c60036d

Create app.py

Files changed (1)
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import gradio as gr
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load the base model and tokenizer, then attach the fine-tuned adapter.
+ peft_model_base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.bfloat16)
+ Gen_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+ Gen_tokenizer.pad_token = Gen_tokenizer.eos_token  # Mistral has no pad token by default
+
+ Gen_model = PeftModel.from_pretrained(peft_model_base,
+                                       'SudiptoPramanik/Mistral_FineTunedModel_for_Non-spam_Mail_Generation',
+                                       torch_dtype=torch.bfloat16,
+                                       is_trainable=False)
+
+ # Move the model to GPU if available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ Gen_model.to(device)
+
+ # Inference function
+ def Gen_make_inference(Question):
+     Prompt = "Read the following Question carefully and then generate an appropriate answer to this question correctly."
+     Grade = "non spam"
+     input_template = f"prompt:{Prompt}\n\n ### Input(Grade):{Grade}\n ### Input(Question):{Question}\n\n### Answer:"
+
+     # Tokenize input
+     input_tokens = Gen_tokenizer(
+         input_template,
+         return_tensors='pt',
+         max_length=220,
+         truncation=True,
+         padding=True
+     ).to(device)
+
+     # Generate output (autocast is only enabled when running on CUDA)
+     with torch.autocast(device_type=device, dtype=torch.bfloat16, enabled=(device == "cuda")):
+         output_tokens = Gen_model.generate(
+             **input_tokens,
+             max_length=250,  # total length including the prompt tokens
+             do_sample=True,  # required for top_k/top_p to take effect
+             num_return_sequences=1,
+             no_repeat_ngram_size=2,
+             top_k=50,  # sample from the 50 most likely tokens for more meaningful diversity
+             top_p=0.95,  # nucleus sampling
+             eos_token_id=Gen_tokenizer.eos_token_id,
+             pad_token_id=Gen_tokenizer.eos_token_id
+         )
+
+     # Decode the response
+     Gen_answer = Gen_tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+
+     # Extract the text after the "Answer:" marker
+     try:
+         text = Gen_answer.split("Answer:")[1].strip()
+         # Handle case where '*' may not exist
+         if '*' in text:
+             text = text.split('*')[0].strip()
+         return text
+     except IndexError:
+         return "Error: Unable to parse the model's response."
+
+ with gr.Blocks() as demo:
+     question = gr.Textbox(label="Mail Prompt")
+     gen_btn = gr.Button("Non Spam Mail Generation")
+     answer_gen = gr.Textbox(label="Desired Mail Response")
+     gen_btn.click(fn=Gen_make_inference, inputs=question, outputs=[answer_gen], api_name="Generator")
+
+ if __name__ == "__main__":
+     demo.launch(inline=False)
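
Because the click handler registers api_name="Generator", the endpoint can also be called programmatically once the app is running. Below is a minimal sketch using gradio_client; the Space id is a placeholder (an assumption, since the commit does not name the Space it is deployed to):

from gradio_client import Client

client = Client("SudiptoPramanik/placeholder-space-id")  # hypothetical Space id, not from the commit
result = client.predict(
    "Write a polite follow-up mail about next week's project review.",  # value for the "Mail Prompt" textbox
    api_name="/Generator",  # Gradio exposes the registered api_name with a leading "/"
)
print(result)  # text that would appear in the "Desired Mail Response" textbox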