Wtzwho committed on
Commit
5e835ad
·
verified ·
1 Parent(s): 4f46af1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -1,38 +1,34 @@
1
  import os
2
  import gradio as gr
3
  from transformers import AutoTokenizer, pipeline
4
- import torch
5
 
6
- # Initialize the model and tokenizer
7
  model_name = "AIFS/Prometh-MOEM-V.01"
 
8
 
9
- HF_TOKEN = os.environ.get("HF_TOKEN")
10
-
11
-
12
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=HF_TOKEN)
13
 
14
  text_generation_pipeline = pipeline(
15
  "text-generation",
16
  model=model_name,
17
- model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": True},
18
- use_auth_token=HF_TOKEN
19
  )
20
 
21
-
22
  def generate_text(user_input):
23
  messages = [{"role": "user", "content": user_input}]
24
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
25
  outputs = text_generation_pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
26
  return outputs[0]["generated_text"]
27
 
28
- # Create the Gradio interface
29
  iface = gr.Interface(
30
  fn=generate_text,
31
- inputs=gr.inputs.Textbox(lines=2, placeholder="Type your question here..."),
32
- outputs=gr.outputs.Textbox(),
33
  title="Prometh-MOEM Text Generation",
34
  description="A text generation model that understands your queries and generates concise, informative responses."
35
  )
36
 
37
- # Run the interface
38
- iface.launch()
 
1
  import os
2
  import gradio as gr
3
  from transformers import AutoTokenizer, pipeline
 
4
 
5
+ # Initialize the model and tokenizer, authenticating with the token read from the HF_TOKEN environment variable
6
  model_name = "AIFS/Prometh-MOEM-V.01"
7
+ hf_token = os.getenv("HF_TOKEN") # Hub access token; os.getenv returns None when HF_TOKEN is unset (equivalent to os.environ.get)
8
 
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
 
 
 
10
 
11
  text_generation_pipeline = pipeline(
12
  "text-generation",
13
  model=model_name,
14
+ model_kwargs={"torch_dtype": "auto", "load_in_4bit": True}, # 'auto' loads the weights in the dtype recorded in the model's checkpoint/config rather than a hard-coded one
15
+ use_auth_token=hf_token
16
  )
17
 
 
18
  def generate_text(user_input):
19
  messages = [{"role": "user", "content": user_input}]
20
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
21
  outputs = text_generation_pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
22
  return outputs[0]["generated_text"]
23
 
24
+ # Build the Gradio interface with the top-level gr.Textbox component (replaces the older gr.inputs.Textbox / gr.outputs.Textbox)
25
  iface = gr.Interface(
26
  fn=generate_text,
27
+ inputs=gr.Textbox(lines=2, placeholder="Type your question here..."),
28
+ outputs=gr.Textbox(),
29
  title="Prometh-MOEM Text Generation",
30
  description="A text generation model that understands your queries and generates concise, informative responses."
31
  )
32
 
33
+ # Run the interface with request queuing enabled
34
+ iface.launch(enable_queue=True) # enable_queue=True for handling high traffic