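# Streamlit chat UI for running CPU inference with the experimental Atlas
# (Deepseek R1 fine-tune) preview checkpoints hosted on Hugging Face.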
import gc
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import re
import os

# Load the Hugging Face token from the environment (e.g. a Space secret) and log in to the Hub
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Define models
MODELS = {
    "athena-1": {
        "name": "🦁 Atlas-Flash",
        "sizes": {
            "1.5B": "Spestly/Atlas-R1-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
    },
}

# Profile pictures
USER_PFP = "https://huggingface.co/front/assets/avatars.png"  # Hugging Face user avatar
AI_PFP = "ai_pfp.png"  # Replace with the path to your AI's image or a URL
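# NOTE (illustrative): more Atlas variants or sizes can be exposed by adding entries to
# MODELS above; each size key maps to a Hugging Face repo id that
# AutoModelForCausalLM.from_pretrained can resolve.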
class AtlasInferenceApp:
    def __init__(self):
        if "current_model" not in st.session_state:
            st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

        st.set_page_config(
            page_title="Atlas Model Inference",
            page_icon="🦁",
            layout="wide",
            menu_items={
                'Get Help': 'https://huggingface.co/collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
                'Report a bug': 'https://huggingface.co/Spestly/Athena-1-1.5B/discussions/new',
                'About': 'Athena Model Inference Platform'
            }
        )
    def clear_memory(self):
        """Free memory between requests: clear the CUDA cache if present, then run garbage collection."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    def load_model(self, model_key, model_size):
        try:
            self.clear_memory()

            # Drop any previously loaded model before loading a new one
            if st.session_state.current_model["model"] is not None:
                del st.session_state.current_model["model"]
                del st.session_state.current_model["tokenizer"]
                self.clear_memory()

            model_path = MODELS[model_key]["sizes"][model_size]

            # Load Qwen-compatible tokenizer and model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            if tokenizer.pad_token is None:
                # Some causal-LM tokenizers ship without a pad token; fall back to EOS
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="cpu",           # Force CPU usage
                torch_dtype=torch.float32,  # Use float32 for CPU
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            # Update session state
            st.session_state.current_model.update({
                "tokenizer": tokenizer,
                "model": model,
                "config": {
                    "name": f"{MODELS[model_key]['name']} {model_size}",
                    "path": model_path,
                }
            })
            return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
        except Exception as e:
            return f"❌ Error: {str(e)}"
    def respond(self, message, max_tokens, temperature, top_p, top_k):
        if not st.session_state.current_model["model"]:
            return "⚠️ Please select and load a model first"

        try:
            # Add a system instruction to guide the model's behavior
            system_instruction = "You are Atlas, a helpful AI assistant trained to help the user. You are a Deepseek R1 fine-tune."
            prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"

            inputs = st.session_state.current_model["tokenizer"](
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            )

            with torch.no_grad():
                output = st.session_state.current_model["model"].generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                    eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
                )

            response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
            return response.split("### Response:")[-1].strip()  # Extract the response
        except Exception as e:
            return f"⚠️ Generation Error: {str(e)}"
        finally:
            self.clear_memory()
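    # NOTE: respond() sends only the latest user message, wrapped in an Alpaca-style
    # "### Instruction / ### Response" template; earlier turns from chat_history are
    # displayed in the UI but are not fed back to the model.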
    def main(self):
        st.title("🦁 AtlasUI - Experimental 🧪")

        with st.sidebar:
            st.header("🔍 Model Selection")
            model_key = st.selectbox(
                "Choose Atlas Variant",
                list(MODELS.keys()),
                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}"
            )
            model_size = st.selectbox(
                "Choose Model Size",
                list(MODELS[model_key]["sizes"].keys())
            )
            if st.button("Load Model"):
                with st.spinner("Loading model... This may take a few minutes."):
                    status = self.load_model(model_key, model_size)
                if status.startswith("✅"):
                    st.success(status)
                else:
                    st.error(status)

            st.header("🔧 Generation Parameters")
            max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
            temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
            top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
            top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)

            if st.button("Clear Chat History"):
                st.session_state.chat_history = []
                st.rerun()

            st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be as expected. Please double-check sensitive information!*")

        # Display chat history
        for message in st.session_state.chat_history:
            with st.chat_message(
                message["role"],
                avatar=USER_PFP if message["role"] == "user" else AI_PFP
            ):
                st.markdown(message["content"])

        # Input box for user messages
        if prompt := st.chat_input("Message Atlas..."):
            st.session_state.chat_history.append({"role": "user", "content": prompt})
            with st.chat_message("user", avatar=USER_PFP):
                st.markdown(prompt)

            with st.chat_message("assistant", avatar=AI_PFP):
                with st.spinner("Generating response..."):
                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k)
                    st.markdown(response)

            st.session_state.chat_history.append({"role": "assistant", "content": response})
def run():
    try:
        app = AtlasInferenceApp()
        app.main()
    except Exception as e:
        st.error(f"⚠️ Application Error: {str(e)}")

if __name__ == "__main__":
    run()
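# To run locally (assuming this file is saved as app.py and Streamlit is installed):
#   streamlit run app.py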