import gc
import os

import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
# Authenticate with Hugging Face if a token is provided; skip login otherwise
# so public checkpoints can still be downloaded without prompting.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Define models
MODELS = {
"athena-1": {
"name": "🦁 Atlas-Flash",
"sizes": {
"1.5B": "Spestly/Atlas-R1-1.5B-Preview",
},
"emoji": "🦁",
"experimental": True,
},
}
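# To expose another checkpoint, add an entry under "sizes". The second repo ID
# below is a hypothetical placeholder, not a published model:
#     "sizes": {
#         "1.5B": "Spestly/Atlas-R1-1.5B-Preview",
#         "7B": "Spestly/Atlas-R1-7B-Preview",  # hypothetical example
#     },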
# Profile pictures
USER_PFP = "https://huggingface.co/front/assets/avatars.png" # Hugging Face user avatar
AI_PFP = "ai_pfp.png" # Replace with the path to your AI's image or a URL
class AtlasInferenceApp:
def __init__(self):
if "current_model" not in st.session_state:
st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
st.set_page_config(
page_title="Atlas Model Inference",
            page_icon="🦁",
layout="wide",
menu_items={
'Get Help': 'https://huggingface.co/collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
'Report a bug': 'https://huggingface.co/Spestly/Athena-1-1.5B/discussions/new',
'About': 'Athena Model Inference Platform'
}
)
    def clear_memory(self):
        """Free cached GPU memory (when a GPU is present) and run garbage collection."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
def load_model(self, model_key, model_size):
try:
self.clear_memory()
            # Release any previously loaded model before loading a new one.
            if st.session_state.current_model["model"] is not None:
                del st.session_state.current_model["model"]
                del st.session_state.current_model["tokenizer"]
                self.clear_memory()
model_path = MODELS[model_key]["sizes"][model_size]
# Load Qwen-compatible tokenizer and model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            # Some causal-LM tokenizers ship without a pad token; fall back to the
            # EOS token so padding and generation do not fail.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="cpu",           # force CPU usage
                torch_dtype=torch.float32,  # float32 is the safe dtype on CPU
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
# Update session state
st.session_state.current_model.update({
"tokenizer": tokenizer,
"model": model,
"config": {
"name": f"{MODELS[model_key]['name']} {model_size}",
"path": model_path,
}
})
return f"βœ… {MODELS[model_key]['name']} {model_size} loaded successfully!"
except Exception as e:
return f"❌ Error: {str(e)}"
def respond(self, message, max_tokens, temperature, top_p, top_k):
if not st.session_state.current_model["model"]:
return "⚠️ Please select and load a model first"
try:
            # Alpaca-style prompt: a system instruction, the user message, and a
            # "### Response:" marker where generation should begin.
            system_instruction = "You are Atlas, a helpful AI assistant trained to help the user. You are a Deepseek R1 fine-tune."
            prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"
            # Tokenize, truncating long prompts so they fit the 512-token budget.
            inputs = st.session_state.current_model["tokenizer"](
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True,
            )
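            # Note: do_sample=True below means temperature, top_p, and top_k all
            # take effect; with greedy decoding they would be ignored.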
with torch.no_grad():
output = st.session_state.current_model["model"].generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=True,
pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
)
response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
return response.split("### Response:")[-1].strip() # Extract the response
except Exception as e:
return f"⚠️ Generation Error: {str(e)}"
finally:
self.clear_memory()
def main(self):
st.title("🦁 AtlasUI - Experimental πŸ§ͺ")
with st.sidebar:
st.header("πŸ›  Model Selection")
model_key = st.selectbox(
"Choose Atlas Variant",
list(MODELS.keys()),
                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}"
)
model_size = st.selectbox(
"Choose Model Size",
list(MODELS[model_key]["sizes"].keys())
)
if st.button("Load Model"):
with st.spinner("Loading model... This may take a few minutes."):
status = self.load_model(model_key, model_size)
st.success(status)
st.header("πŸ”§ Generation Parameters")
max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)
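            # These sliders map directly onto the model.generate() keyword
            # arguments used in respond(); the defaults favor focused output.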
if st.button("Clear Chat History"):
st.session_state.chat_history = []
st.rerun()
st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be expected. Please double-check sensitive information!*")
# Display chat history
for message in st.session_state.chat_history:
with st.chat_message(
message["role"],
avatar=USER_PFP if message["role"] == "user" else AI_PFP
):
st.markdown(message["content"])
# Input box for user messages
if prompt := st.chat_input("Message Atlas..."):
st.session_state.chat_history.append({"role": "user", "content": prompt})
with st.chat_message("user", avatar=USER_PFP):
st.markdown(prompt)
with st.chat_message("assistant", avatar=AI_PFP):
with st.spinner("Generating response..."):
response = self.respond(prompt, max_tokens, temperature, top_p, top_k)
st.markdown(response)
st.session_state.chat_history.append({"role": "assistant", "content": response})
def run():
try:
app = AtlasInferenceApp()
app.main()
except Exception as e:
st.error(f"⚠️ Application Error: {str(e)}")
if __name__ == "__main__":
run()
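
# Local usage (a minimal sketch): set HF_TOKEN only if the checkpoint requires
# authentication, then launch the Streamlit server:
#     export HF_TOKEN=<your-token>   # optional for public models
#     streamlit run app.py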