hysts (HF staff) committed
Commit aebdcd2
1 Parent(s): ab59a6c

Update app.py

Files changed (1):
  1. app.py +12 -13
app.py CHANGED
@@ -1,12 +1,9 @@
 import gradio as gr
 import os
 import spaces
-from transformers import GemmaTokenizer, AutoModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
 
-# Set an environment variable
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 TITLE = '''
 <h1 style="text-align: center;">Meta Llama3.1 8B <a href="https://huggingface.co/spaces/ysharma/Chat_with_Meta_llama3_1_8b?duplicate=true" id="duplicate-button"><button style="color:white">Duplicate this Space</button></a></h1>
@@ -47,16 +44,18 @@ h1 {
 }
 """
 
-model = "llhf/Meta-Llama-3.1-8B-Instruct"
+model_id = "llhf/Meta-Llama-3.1-8B-Instruct"
 
 # Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained(f"{model}")
-model = AutoModelForCausalLM.from_pretrained(f"{model}", device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
 terminators = [
     tokenizer.eos_token_id,
     tokenizer.convert_tokens_to_ids("<|eot_id|>")
 ]
 
+MAX_INPUT_TOKEN_LENGTH = 4096
+
 # Gradio inference function
 @spaces.GPU(duration=120)
 def chat_llama3_1_8b(message: str,
@@ -79,7 +78,11 @@ def chat_llama3_1_8b(message: str,
     conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
@@ -87,14 +90,11 @@
         input_ids= input_ids,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
-        do_sample=True,
+        do_sample=temperature != 0, # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
         temperature=temperature,
         eos_token_id=terminators,
     )
-    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
-    if temperature == 0:
-        generate_kwargs['do_sample'] = False
-
+
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
@@ -148,4 +148,3 @@ with gr.Blocks(fill_height=True, css=css) as demo:
 
 if __name__ == "__main__":
     demo.launch()
-
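Note on the trimming added in the third hunk: the input is cut from the left, so when a conversation exceeds MAX_INPUT_TOKEN_LENGTH tokens the oldest turns are dropped and the newest survive. A toy illustration of the slice semantics (hypothetical tensor, not part of the commit):

    import torch

    MAX_INPUT_TOKEN_LENGTH = 4096
    input_ids = torch.arange(5000).unsqueeze(0)       # pretend 5000-token prompt, shape (1, 5000)
    trimmed = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # keep only the newest 4096 tokens
    assert trimmed.shape == (1, 4096)
    assert trimmed[0, 0].item() == 904                # the 904 oldest tokens were dropped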
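The do_sample=temperature != 0 line in the fourth hunk replaces the old post-hoc override. Transformers validates that temperature is a strictly positive float when do_sample=True, so a temperature of 0 from the UI slider crashed generation; with do_sample=False (greedy decoding) the temperature value is ignored. A minimal sketch of the equivalent logic, assuming the same input_ids, streamer, and terminators as in the Space:

    def build_generate_kwargs(input_ids, streamer, max_new_tokens, temperature, terminators):
        # Greedy decoding when the slider is at 0; sampling otherwise.
        # Greedy generation never reads `temperature`, so passing 0 is safe.
        return dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=temperature != 0,
            temperature=temperature,
            eos_token_id=terminators,
        )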
 
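For context, the unchanged code around the last hunks streams tokens by running model.generate in a background thread and draining a TextIteratorStreamer; the trimmed input_ids feed straight into it. A minimal sketch of that pattern under the same assumptions (tokenizer, model, and terminators loaded as in the diff):

    from threading import Thread

    from transformers import TextIteratorStreamer

    def stream_reply(input_ids, max_new_tokens=1024, temperature=0.7):
        # The streamer yields decoded text chunks as generate() emits tokens.
        streamer = TextIteratorStreamer(
            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=temperature != 0,
            temperature=temperature,
            eos_token_id=terminators,
        )
        # generate() blocks until done, so it runs on a worker thread while
        # this generator yields partial text for Gradio to render live.
        Thread(target=model.generate, kwargs=generate_kwargs).start()

        partial = []
        for chunk in streamer:
            partial.append(chunk)
            yield "".join(partial)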