mrcuddle commited on
Commit
0c92bcf
·
verified ·
1 Parent(s): 2990998

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -4,10 +4,11 @@ import torch
4
  import spaces
5
 
6
  # Load the model and tokenizer
7
- model_name = "mrcuddle/DarkHermes3-Llama3.2-3B-Instruct"
8
  device = "cuda" if torch.cuda.is_available() else "cpu" # Detect GPU or default to CPU
 
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- model = AutoModelForCausalLM.from_pretrained(model_name).to(device) # Move model to the appropriate device
11
  model.eval() # Ensure the model is in evaluation mode
12
 
13
  # Define the system prompt
@@ -30,8 +31,8 @@ def chatbot(message, history):
30
  conversation += "".join([f"User: {msg}\nBot: {resp}\n" for msg, resp in history])
31
  conversation += f"User: {message}\nBot:"
32
 
33
- # Tokenize the input and move it to the correct device
34
- inputs = tokenizer(conversation, return_tensors="pt", truncation=True, max_length=1024).to(device)
35
 
36
  # Generate a response
37
  outputs = model.generate(
 
4
  import spaces
5
 
6
  # Load the model and tokenizer
7
+ model_name = "mrcuddle/Dark-Hermes3-Llama3.2-3B"
8
  device = "cuda" if torch.cuda.is_available() else "cpu" # Detect GPU or default to CPU
9
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32 # Use bfloat16 for mixed precision on GPU
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device) # Ensure model is on the correct device
12
  model.eval() # Ensure the model is in evaluation mode
13
 
14
  # Define the system prompt
 
31
  conversation += "".join([f"User: {msg}\nBot: {resp}\n" for msg, resp in history])
32
  conversation += f"User: {message}\nBot:"
33
 
34
+ # Tokenize the input and move it to the correct device and dtype
35
+ inputs = tokenizer(conversation, return_tensors="pt", truncation=True, max_length=1024).to(device, dtype=dtype)
36
 
37
  # Generate a response
38
  outputs = model.generate(