# NOTE(review): the original top of this file contained web-page scrape
# residue (Space status lines, file size, commit-hash and line-number
# columns) that was not Python and broke parsing; replaced with this note.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import torch
import os
# Hugging Face token login.
# Authenticate against the Hub so gated/private checkpoints can be pulled.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    # login(token=None) falls back to an interactive prompt, which hangs or
    # errors in a headless Space — only authenticate when a token is set.
    login(token=HF_TOKEN)
# Define models.
# Registry of selectable model families. Each entry maps a UI key to:
#   name             – display label shown in the UI
#   sizes            – mapping of size label -> Hub repo id
#   emoji            – decorative icon (note: strings appear mojibake-encoded
#                      in this file; left byte-identical on purpose)
#   experimental     – informational flag (not read elsewhere in this file)
#   is_vision        – True toggles the image-upload widget in the UI
#   system_prompt_env – env var consulted for the system prompt at generation
MODELS = {
"atlas-flash-1215": {
"name": "π¦ Atlas-Flash 1215",
"sizes": {
"1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
},
"emoji": "π¦",
"experimental": True,
"is_vision": False,
"system_prompt_env": "ATLAS_FLASH_1215",
},
"atlas-pro-0403": {
"name": "π Atlas-Pro 0403",
"sizes": {
"1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
},
"emoji": "π",
"experimental": True,
"is_vision": False,
"system_prompt_env": "ATLAS_PRO_0403",
},
}
# Load default model.
# The model/size pair loaded at startup; generate_response reloads only when
# the user picks a different checkpoint.
default_model_key = "atlas-pro-0403"
default_size = "1.5B"
default_model = MODELS[default_model_key]["sizes"][default_size]
def load_model(model_name):
    """Download *model_name* from the Hub and prepare it for CPU inference.

    Returns a ``(tokenizer, model)`` pair with the model in eval mode.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    # float32 + low_cpu_mem_usage: conservative settings for CPU-only Spaces.
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    lm.eval()
    return tok, lm
# Eagerly load the default checkpoint at import time so the first request
# doesn't pay the download/initialization cost.
tokenizer, model = load_model(default_model)
# Generate response function
def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
    """Run one chat turn and return the updated history for the Chatbot.

    Args:
        message: user's input text.
        image: filepath from the image widget (only noted in the prompt for
            vision-flagged models; the pixels are never read here).
        history: current Chatbot history (list of (user, bot) pairs) or None.
        model_key / model_size: selection keys into ``MODELS``.
        temperature, top_p, max_new_tokens: sampling parameters.

    Returns:
        The chat history extended with this (message, response) pair —
        the format ``gr.Chatbot`` expects (the old code returned a bare
        string, which broke the chat display).
    """
    global tokenizer, model
    selected_model = MODELS[model_key]["sizes"][model_size]
    # Track which checkpoint is actually resident. The previous comparison
    # against `default_model` reloaded on every call once the user switched
    # away, and — worse — skipped reloading when switching *back* to the
    # default, leaving the wrong model in memory.
    if getattr(generate_response, "_loaded_model", default_model) != selected_model:
        tokenizer, model = load_model(selected_model)
        generate_response._loaded_model = selected_model

    system_prompt_env = MODELS[model_key]["system_prompt_env"]
    system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")

    # Alpaca-style instruction prompt; vision models just get a textual note
    # that an image was supplied (the image itself is not encoded).
    if MODELS[model_key]["is_vision"]:
        image_info = "An image has been provided as input."
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n{image_info}\n\n### Response:"
    else:
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"

    inputs = tokenizer(instruction, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the marker.
    response = response.split("### Response:")[-1].strip()

    # Gradio may pass None on the first turn.
    return (history or []) + [(message, response)]
def create_interface():
    """Assemble and return the Gradio Blocks UI for the Atlas model family."""
    with gr.Blocks(title="π Atlas-Pro/Flash/Vision Interface", theme="soft") as iface:
        gr.Markdown("Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Flash (Coming Soon!). Upload images for vision models!")

        # --- model selection ---
        model_key_selector = gr.Dropdown(
            label="Model", choices=list(MODELS.keys()), value=default_model_key
        )
        model_size_selector = gr.Dropdown(
            label="Model Size",
            choices=list(MODELS[default_model_key]["sizes"].keys()),
            value=default_size,
        )

        # --- user inputs (image hidden until a vision model is chosen) ---
        image_input = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)
        message_input = gr.Textbox(label="Message", placeholder="Type your message here...")

        # --- sampling controls ---
        temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
        top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
        max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)

        chat_output = gr.Chatbot(label="Chatbot")
        submit_button = gr.Button("Submit")

        def update_components(model_key):
            # On model-family change: refresh the size choices and toggle
            # the image uploader for vision-capable models.
            info = MODELS[model_key]
            sizes = list(info["sizes"].keys())
            return [
                gr.Dropdown(choices=sizes, value=sizes[0]),
                gr.Image(visible=info["is_vision"]),
            ]

        model_key_selector.change(
            fn=update_components,
            inputs=model_key_selector,
            outputs=[model_size_selector, image_input],
        )

        submit_button.click(
            fn=generate_response,
            inputs=[
                message_input,
                image_input,
                chat_output,
                model_key_selector,
                model_size_selector,
                temperature_slider,
                top_p_slider,
                max_tokens_slider,
            ],
            outputs=chat_output,
        )
    return iface
# Start the app. (The stray trailing " |" on this line was page-scrape
# residue and a SyntaxError; removed.)
create_interface().launch()