import gc
import os

import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the Hugging Face token from the environment; only log in when one is
# actually set, since login() without a token falls back to an interactive prompt.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Define models
MODELS = {
    "athena-1": {
        "name": "🦁 Atlas-Flash",
        "sizes": {
            "1.5B": "Spestly/Atlas-R1-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
    },
}

# Profile pictures
USER_PFP = "https://huggingface.co/front/assets/avatars.png"  # Hugging Face user avatar
AI_PFP = "ai_pfp.png"  # Replace with the path to your AI's image or a URL


class AtlasInferenceApp:
    def __init__(self):
        # set_page_config must be the first Streamlit call in the script
        st.set_page_config(
            page_title="Atlas Model Inference",
            page_icon="🦁",
            layout="wide",
            menu_items={
                'Get Help': 'https://huggingface.co/collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
                'Report a bug': 'https://huggingface.co/Spestly/Athena-1-1.5B/discussions/new',
                'About': 'Athena Model Inference Platform'
            }
        )

        if "current_model" not in st.session_state:
            st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

    def clear_memory(self):
        """Free cached GPU memory (if any) and force a garbage-collection pass."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def load_model(self, model_key, model_size):
        try:
            self.clear_memory()

            # Release any previously loaded model before loading a new one
            if st.session_state.current_model["model"] is not None:
                del st.session_state.current_model["model"]
                del st.session_state.current_model["tokenizer"]
                self.clear_memory()

            model_path = MODELS[model_key]["sizes"][model_size]

            # Load Qwen-compatible tokenizer and model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="cpu",           # Force CPU usage
                torch_dtype=torch.float32,  # Use float32 for CPU
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )

            # Update session state
            st.session_state.current_model.update({
                "tokenizer": tokenizer,
                "model": model,
                "config": {
                    "name": f"{MODELS[model_key]['name']} {model_size}",
                    "path": model_path,
                },
            })
            return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
        except Exception as e:
            return f"❌ Error: {str(e)}"
    def respond(self, message, max_tokens, temperature, top_p, top_k):
        if not st.session_state.current_model["model"]:
            return "⚠️ Please select and load a model first"
        try:
            tokenizer = st.session_state.current_model["tokenizer"]
            model = st.session_state.current_model["model"]

            # System instruction guiding the model's behavior, wrapped in an
            # Alpaca-style Instruction/Response template
            system_instruction = (
                "You are Atlas, a helpful AI assistant trained to help the user. "
                "You are a Deepseek R1 fine-tune."
            )
            prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"

            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True,
            )

            with torch.no_grad():
                output = model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    # Fall back to EOS when the tokenizer defines no pad token
                    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

            response = tokenizer.decode(output[0], skip_special_tokens=True)
            # Everything after the final "### Response:" marker is the answer
            return response.split("### Response:")[-1].strip()
        except Exception as e:
            return f"⚠️ Generation Error: {str(e)}"
        finally:
            self.clear_memory()

    def main(self):
        st.title("🦁 AtlasUI - Experimental 🧪")

        with st.sidebar:
            st.header("🛠 Model Selection")
            model_key = st.selectbox(
                "Choose Atlas Variant",
                list(MODELS.keys()),
                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}",
            )
            model_size = st.selectbox(
                "Choose Model Size",
                list(MODELS[model_key]["sizes"].keys()),
            )

            if st.button("Load Model"):
                with st.spinner("Loading model... This may take a few minutes."):
                    status = self.load_model(model_key, model_size)
                    st.success(status)

            st.header("🔧 Generation Parameters")
            max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
            temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
            top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
            top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)

            if st.button("Clear Chat History"):
                st.session_state.chat_history = []
                st.rerun()

            st.markdown(
                "*⚠️ CAUTION: Atlas is an experimental model and this is only a preview. "
                "Responses may be unexpected. Please double-check sensitive information!*"
            )

        # Display chat history
        for message in st.session_state.chat_history:
            with st.chat_message(
                message["role"],
                avatar=USER_PFP if message["role"] == "user" else AI_PFP,
            ):
                st.markdown(message["content"])

        # Input box for user messages
        if prompt := st.chat_input("Message Atlas..."):
            st.session_state.chat_history.append({"role": "user", "content": prompt})
            with st.chat_message("user", avatar=USER_PFP):
                st.markdown(prompt)

            with st.chat_message("assistant", avatar=AI_PFP):
                with st.spinner("Generating response..."):
                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k)
                    st.markdown(response)
            st.session_state.chat_history.append({"role": "assistant", "content": response})


def run():
    try:
        app = AtlasInferenceApp()
        app.main()
    except Exception as e:
        st.error(f"⚠️ Application Error: {str(e)}")


if __name__ == "__main__":
    run()
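
# Usage note (a sketch, not part of the original script — the filename "app.py"
# is an assumption):
#
#   export HF_TOKEN=hf_...        # a token with read access to the model repo
#   streamlit run app.py
#
# In float32 the 1.5B checkpoint needs roughly 6 GB of RAM for the weights alone
# (1.5B parameters x 4 bytes), so expect slow, memory-hungry CPU inference.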