sbicy commited on
Commit
0922636
·
verified ·
1 Parent(s): d3fb7ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -24
app.py CHANGED
@@ -3,42 +3,63 @@ import os
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import gradio as gr
5
 
6
- # Load the model and tokenizer
7
- model_name = "distilgpt2"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForCausalLM.from_pretrained(model_name)
10
-
11
- # Define the function to generate a response
12
- def generate_response(prompt):
13
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
14
  outputs = model.generate(
15
  inputs.input_ids,
16
- max_length=70,
17
  do_sample=True,
18
- temperature=0.6,
19
- top_p=0.9,
20
- repetition_penalty=1.2,
21
  pad_token_id=tokenizer.eos_token_id
22
  )
23
  response = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
24
- return response
 
 
 
 
 
25
 
26
- # Persona-based response function
27
- def persona_response(prompt, persona="You are a helpful talking dog that answers in short, simple phrases."):
28
- full_prompt = f"{persona}: {prompt}"
29
- return generate_response(full_prompt)
30
 
31
- # Define Gradio interface function
32
- def chat_interface(user_input, persona="You are a helpful talking dog that answers in short, simple phrases."):
33
- return persona_response(user_input, persona)
34
 
35
- # Gradio interface setup
36
  interface = gr.Interface(
37
  fn=chat_interface,
38
- inputs=["text", "text"],
 
 
 
 
 
 
 
 
39
  outputs="text",
40
- title="Simple Chatbot",
41
- description="Chat with the bot! Add a persona like 'I am a shopping assistant.'"
42
  )
43
 
44
  # Launch the Gradio app
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import gradio as gr
5
 
6
+ # Function to load model and tokenizer based on selection
7
+ def load_model(model_name):
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ return tokenizer, model
11
+
12
+ # Define the function to generate a response with adjustable parameters and model-specific adjustments
13
+ def generate_response(prompt, model_name, persona="I am a helpful assistant.", temperature=0.7, top_p=0.9, repetition_penalty=1.2, max_length=70):
14
+ # Load the chosen model and tokenizer
15
+ tokenizer, model = load_model(model_name)
16
+
17
+ # Adjust the prompt format for DialoGPT
18
+ if model_name == "microsoft/DialoGPT-small":
19
+ full_prompt = f"User: {prompt}\nBot:" # Structure as a conversation
20
+ else:
21
+ full_prompt = f"{persona}: {prompt}" # Standard format for other models
22
+
23
+ # Tokenize and generate response
24
+ inputs = tokenizer(full_prompt, return_tensors="pt")
25
  outputs = model.generate(
26
  inputs.input_ids,
27
+ max_length=max_length,
28
  do_sample=True,
29
+ temperature=temperature,
30
+ top_p=top_p,
31
+ repetition_penalty=repetition_penalty,
32
  pad_token_id=tokenizer.eos_token_id
33
  )
34
  response = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
35
+
36
+ # Trim the prompt if it appears in the response
37
+ if model_name == "microsoft/DialoGPT-small":
38
+ response_without_prompt = response.split("Bot:", 1)[-1].strip()
39
+ else:
40
+ response_without_prompt = response.split(":", 1)[-1].strip()
41
 
42
+ return response_without_prompt if response_without_prompt else "I'm not sure how to respond to that."
 
 
 
43
 
44
+ # Define Gradio interface function with model selection
45
+ def chat_interface(user_input, model_choice, persona="I am a helpful assistant", temperature=0.7, top_p=0.9, repetition_penalty=1.2, max_length=50):
46
+ return generate_response(user_input, model_choice, persona, temperature, top_p, repetition_penalty, max_length)
47
 
48
+ # Set up Gradio interface with model selection and parameter sliders
49
  interface = gr.Interface(
50
  fn=chat_interface,
51
+ inputs=[
52
+ gr.Textbox(label="User Input"),
53
+ gr.Dropdown(choices=["distilgpt2", "gpt2", "microsoft/DialoGPT-small"], label="Model Choice", value="distilgpt2"),
54
+ gr.Textbox(label="Persona", value="You are a helpful assistant."),
55
+ gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1),
56
+ gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.9, step=0.1),
57
+ gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.1),
58
+ gr.Slider(label="Max Length", minimum=10, maximum=100, value=50, step=5)
59
+ ],
60
  outputs="text",
61
+ title="Interactive Chatbot with Model Comparison",
62
+ description="Chat with the bot! Select a model and adjust parameters to see how they affect the response."
63
  )
64
 
65
  # Launch the Gradio app