Update app.py
app.py CHANGED
@@ -5,6 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import login
 import re
 import os
+from PIL import Image
 
 # Load Hugging Face token
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -19,6 +20,7 @@ MODELS = {
         },
         "emoji": "π¦",
         "experimental": True,
+        "is_vision": False,  # Enable vision support for this model
     },
 }
 
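For orientation, the new `is_vision` key sits alongside the existing per-model fields; a registry entry that actually enables vision might look like the minimal sketch below. The entry name, repo ID, and emoji are placeholders, not part of this commit.

```python
# Hypothetical MODELS entry -- only the "is_vision" key comes from this commit;
# the entry name, repo ID, and other values are illustrative placeholders.
MODELS = {
    "Atlas-Vision (example)": {
        "repo_id": "example-org/atlas-vision",  # placeholder checkpoint
        "emoji": "🖼️",
        "experimental": True,
        "is_vision": True,  # the UI offers an image uploader only when this is True
    },
}
```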
@@ -84,7 +86,7 @@ class AtlasInferenceApp:
         except Exception as e:
             return f"❌ Error: {str(e)}"
 
-    def respond(self, message, max_tokens, temperature, top_p, top_k):
+    def respond(self, message, max_tokens, temperature, top_p, top_k, image=None):
         if not st.session_state.current_model["model"]:
             return "⚠️ Please select and load a model first"
 
@@ -104,6 +106,7 @@ class AtlasInferenceApp:
             # Generate response with streaming
             response_container = st.empty()  # Placeholder for streaming text
             full_response = ""
+            generated_tokens = []  # Track generated tokens to avoid duplicates
             with torch.no_grad():
                 for chunk in st.session_state.current_model["model"].generate(
                     input_ids=inputs.input_ids,
@@ -116,19 +119,18 @@ class AtlasInferenceApp:
                     pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                     eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
                 ):
-                    # Decode the
-
-
-
-
-
-
-
-
-
-
-
-                    break
+                    # Decode only the new tokens
+                    new_tokens = chunk[:, inputs.input_ids.shape[1]:]  # Exclude input tokens
+                    generated_tokens.extend(new_tokens[0].tolist())  # Add new tokens to the list
+                    chunk_text = st.session_state.current_model["tokenizer"].decode(generated_tokens, skip_special_tokens=True)
+
+                    # Remove the prompt from the response
+                    if prompt in chunk_text:
+                        chunk_text = chunk_text.replace(prompt, "").strip()
+
+                    # Update the response
+                    full_response = chunk_text
+                    response_container.markdown(full_response)
 
                     # Stop if the response is too long or incomplete
                     if len(full_response) >= max_tokens * 4:  # Approximate token-to-character ratio
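One caveat with the loop above: a stock `transformers` `model.generate()` call returns the finished sequence rather than yielding chunks, so iterating over it does not stream token by token. The usual streaming pattern pairs `TextIteratorStreamer` with `generate()` running in a background thread. The sketch below shows that pattern under the assumption that `model` and `tokenizer` are the objects held in `st.session_state.current_model`; the helper name `stream_reply` is made up for illustration.

```python
# Sketch of incremental streaming with transformers' TextIteratorStreamer.
# "model" and "tokenizer" stand in for st.session_state.current_model["model"]
# and ["tokenizer"]; the helper name is hypothetical.
from threading import Thread

import streamlit as st
from transformers import TextIteratorStreamer


def stream_reply(model, tokenizer, prompt, max_tokens, temperature, top_p, top_k):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until finished, so run it in a thread and drain the streamer.
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            streamer=streamer,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
    ).start()

    placeholder = st.empty()
    full_response = ""
    for new_text in streamer:  # yields decoded text as tokens are produced
        full_response += new_text
        placeholder.markdown(full_response)
    return full_response
```

Because the streamer skips the prompt and decodes only newly generated tokens, there is no need to strip the prompt from the accumulated text by hand.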
@@ -182,16 +184,25 @@ class AtlasInferenceApp:
                 avatar=USER_PFP if message["role"] == "user" else AI_PFP
             ):
                 st.markdown(message["content"])
+                if "image" in message:
+                    st.image(message["image"], caption="Uploaded Image", use_column_width=True)
 
         # Input box for user messages
         if prompt := st.chat_input("Message Atlas..."):
-
+            # Allow image upload if the model supports vision
+            uploaded_image = None
+            if MODELS[model_key]["is_vision"]:
+                uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
+
+            st.session_state.chat_history.append({"role": "user", "content": prompt, "image": uploaded_image})
             with st.chat_message("user", avatar=USER_PFP):
                 st.markdown(prompt)
+                if uploaded_image:
+                    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
 
             with st.chat_message("assistant", avatar=AI_PFP):
                 with st.spinner("Generating response..."):
-                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k)
+                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k, image=uploaded_image)
                     st.markdown(response)
 
             st.session_state.chat_history.append({"role": "assistant", "content": response})
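The commit threads `image=uploaded_image` into `respond()` but leaves `is_vision` set to `False`, so the diff never shows the image actually reaching a model. For a vision-language checkpoint, a common approach is to open the Streamlit upload with PIL (hence the new `from PIL import Image`) and pass it through the checkpoint's `AutoProcessor` together with the text; the sketch below is an assumption about how that could look, with the repo ID and helper name as placeholders.

```python
# Hedged sketch: feeding an uploaded image to a vision-language model.
# The repo ID "example-org/atlas-vlm" and the helper name are placeholders;
# this step is not part of the commit itself.
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

processor = AutoProcessor.from_pretrained("example-org/atlas-vlm")
model = AutoModelForVision2Seq.from_pretrained("example-org/atlas-vlm")


def respond_with_image(prompt, uploaded_file, max_tokens=256):
    # st.file_uploader returns a file-like object that PIL can open directly.
    image = Image.open(uploaded_file).convert("RGB")
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=max_tokens)
    return processor.batch_decode(output_ids, skip_special_tokens=True)[0]
```

When `is_vision` stays `False`, `uploaded_image` remains `None` and `respond()` can simply ignore the extra argument.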