Spestly committed on
Commit 7c8d3e4 · verified · 1 Parent(s): 8653d1f

Update app.py

Files changed (1)
  1. app.py +27 -16
app.py CHANGED
@@ -5,6 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import login
 import re
 import os
+from PIL import Image
 
 # Load Hugging Face token
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -19,6 +20,7 @@ MODELS = {
         },
         "emoji": "🦁",
         "experimental": True,
+        "is_vision": False,  # Enable vision support for this model
     },
 }
 
@@ -84,7 +86,7 @@ class AtlasInferenceApp:
         except Exception as e:
             return f"❌ Error: {str(e)}"
 
-    def respond(self, message, max_tokens, temperature, top_p, top_k):
+    def respond(self, message, max_tokens, temperature, top_p, top_k, image=None):
         if not st.session_state.current_model["model"]:
             return "⚠️ Please select and load a model first"
 
@@ -104,6 +106,7 @@ class AtlasInferenceApp:
         # Generate response with streaming
        response_container = st.empty()  # Placeholder for streaming text
         full_response = ""
+        generated_tokens = []  # Track generated tokens to avoid duplicates
         with torch.no_grad():
             for chunk in st.session_state.current_model["model"].generate(
                 input_ids=inputs.input_ids,
@@ -116,19 +119,18 @@ class AtlasInferenceApp:
                 pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                 eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
             ):
-                # Decode the chunk and update the response
-                try:
-                    chunk_text = st.session_state.current_model["tokenizer"].decode(chunk, skip_special_tokens=True)
-
-                    # Remove the prompt from the response
-                    if prompt in chunk_text:
-                        chunk_text = chunk_text.replace(prompt, "").strip()
-
-                    full_response += chunk_text
-                    response_container.markdown(full_response)
-                except Exception as decode_error:
-                    st.error(f"⚠️ Token Decoding Error: {str(decode_error)}")
-                    break
+                # Decode only the new tokens
+                new_tokens = chunk[:, inputs.input_ids.shape[1]:]  # Exclude input tokens
+                generated_tokens.extend(new_tokens[0].tolist())  # Add new tokens to the list
+                chunk_text = st.session_state.current_model["tokenizer"].decode(generated_tokens, skip_special_tokens=True)
+
+                # Remove the prompt from the response
+                if prompt in chunk_text:
+                    chunk_text = chunk_text.replace(prompt, "").strip()
+
+                # Update the response
+                full_response = chunk_text
+                response_container.markdown(full_response)
 
                 # Stop if the response is too long or incomplete
                 if len(full_response) >= max_tokens * 4:  # Approximate token-to-character ratio
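
The rewritten loop above decodes only the tokens that follow the prompt, instead of re-decoding whole chunks and stripping the prompt string out of the text afterwards. Below is a minimal standalone sketch of that slicing idea; the checkpoint name and generation settings are placeholders for illustration, not the app's actual configuration.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Placeholder checkpoint for illustration only; the app loads its models from MODELS.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Hello, my name is"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=20,
        pad_token_id=tokenizer.eos_token_id,
    )

# Same idea as `new_tokens = chunk[:, inputs.input_ids.shape[1]:]` in the diff:
# the first input_ids.shape[1] positions are the prompt, so slicing them off
# leaves only the newly generated tokens to decode.
new_tokens = output_ids[:, inputs.input_ids.shape[1]:]
print(tokenizer.decode(new_tokens[0], skip_special_tokens=True))

Because the commit accumulates every generated token id in generated_tokens and decodes the whole list on each iteration, full_response is assigned rather than appended to, which keeps the streamed text from duplicating earlier output.
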
@@ -182,16 +184,25 @@ class AtlasInferenceApp:
                 avatar=USER_PFP if message["role"] == "user" else AI_PFP
             ):
                 st.markdown(message["content"])
+                if "image" in message:
+                    st.image(message["image"], caption="Uploaded Image", use_column_width=True)
 
         # Input box for user messages
         if prompt := st.chat_input("Message Atlas..."):
-            st.session_state.chat_history.append({"role": "user", "content": prompt})
+            # Allow image upload if the model supports vision
+            uploaded_image = None
+            if MODELS[model_key]["is_vision"]:
+                uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
+
+            st.session_state.chat_history.append({"role": "user", "content": prompt, "image": uploaded_image})
             with st.chat_message("user", avatar=USER_PFP):
                 st.markdown(prompt)
+                if uploaded_image:
+                    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
 
             with st.chat_message("assistant", avatar=AI_PFP):
                 with st.spinner("Generating response..."):
-                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k)
+                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k, image=uploaded_image)
                     st.markdown(response)
 
             st.session_state.chat_history.append({"role": "assistant", "content": response})
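
The UI changes above gate the image uploader on the selected model's "is_vision" flag and store the upload alongside the user message so it can be re-rendered when the chat history is replayed. Below is a minimal Streamlit sketch of that gating, arranged as a standalone script rather than the app's exact control flow; the MODELS entry and model_key are hypothetical stand-ins, and the PIL.Image.open step is an assumption, since the diff does not show how respond() consumes the image.

import streamlit as st
from PIL import Image

# Hypothetical registry entry; only the is_vision flag matters for this sketch.
MODELS = {"atlas-mini": {"emoji": "🦁", "experimental": True, "is_vision": True}}
model_key = "atlas-mini"

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

uploaded_image = None
if MODELS[model_key]["is_vision"]:
    # st.file_uploader returns an UploadedFile (or None); PIL can open it directly.
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        uploaded_image = Image.open(uploaded_file)
        st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

if prompt := st.chat_input("Message Atlas..."):
    # Keep the (possibly None) image next to the message, as the commit does,
    # so replaying the history can re-render it with st.image.
    st.session_state.chat_history.append(
        {"role": "user", "content": prompt, "image": uploaded_image}
    )

In the commit itself the uploader is created inside the chat_input branch, and the raw upload is also forwarded to respond() through the new image=None parameter.
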
 