import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import torch
import os

# Log in to the Hugging Face Hub (skipped when no token is configured)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
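# (On Hugging Face Spaces the token is usually stored as a repository secret,
# which surfaces here as an environment variable.)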

# Define models
MODELS = {
    "atlas-flash-1215": {
        "name": "🦁 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "πŸ† Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "πŸ†",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}
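
# A new checkpoint can be registered with an entry of the same shape. The repo id
# and env var below are hypothetical placeholders for the not-yet-released
# AtlasV-Flash mentioned in the interface description:
#
# MODELS["atlasv-flash"] = {
#     "name": "🦁 AtlasV-Flash",
#     "sizes": {"1.5B": "Spestly/AtlasV-Flash-1.5B-Preview"},  # placeholder repo id
#     "emoji": "🦁",
#     "experimental": True,
#     "is_vision": True,  # makes the image upload widget visible
#     "system_prompt_env": "ATLASV_FLASH",
# }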

# Load default model
default_model_key = "atlas-pro-0403"
default_size = "1.5B"
default_model = MODELS[default_model_key]["sizes"][default_size]

def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.float32, 
        low_cpu_mem_usage=True
    )
    model.eval()
    return tokenizer, model

tokenizer, model = load_model(default_model)
current_model_name = default_model  # track which checkpoint is in memory
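
# A minimal GPU load path, sketched under the assumption that a CUDA device and
# the `accelerate` package are available (the loader above uses float32, which
# suits CPU-only Spaces):
#
# def load_model_gpu(model_name):
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         torch_dtype=torch.float16,  # half precision halves memory use
#         device_map="auto",          # requires the `accelerate` package
#     )
#     model.eval()
#     return tokenizer, model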

# Generate a response and append the (message, response) pair to the chat history
def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
    global tokenizer, model, current_model_name
    selected_model = MODELS[model_key]["sizes"][model_size]
    # Reload only when the requested checkpoint differs from the one in memory
    if selected_model != current_model_name:
        tokenizer, model = load_model(selected_model)
        current_model_name = selected_model
    
    system_prompt_env = MODELS[model_key]["system_prompt_env"]
    system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")
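    # The prompt text itself lives in an env var (ATLAS_PRO_0403 / ATLAS_FLASH_1215),
    # so it can be changed without redeploying; the string above is the fallback.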
    
    if MODELS[model_key]["is_vision"]:
        image_info = "An image has been provided as input."
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n{image_info}\n\n### Response:"
    else:
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
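    # Note: even when is_vision is True, the raw image is never passed to the
    # model; only a textual marker is added, since these checkpoints are text-only.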
    
    inputs = tokenizer(instruction, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the final "### Response:" marker
    response = response.split("### Response:")[-1].strip()
    # gr.Chatbot expects a list of (user, assistant) pairs, not a bare string
    history = (history or []) + [(message, response)]
    return history
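
# Quick smoke test outside the UI, assuming the default checkpoint is in memory:
#
# history = generate_response("Hello!", None, [], default_model_key, default_size,
#                             temperature=0.7, top_p=0.9, max_new_tokens=64)
# print(history[-1][1])  # the model's reply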

def create_interface():
    with gr.Blocks(title="🌟 Atlas-Pro/Flash/Vision Interface", theme="soft") as iface:
        gr.Markdown(
            "# 🌟 Atlas-Pro/Flash/Vision Interface\n"
            "Interact with models like Atlas-Pro, Atlas-Flash, and AtlasV-Flash (Coming Soon!). "
            "Upload images for vision models!"
        )

        model_key_selector = gr.Dropdown(
            label="Model",
            choices=list(MODELS.keys()),
            value=default_model_key
        )
        model_size_selector = gr.Dropdown(
            label="Model Size",
            choices=list(MODELS[default_model_key]["sizes"].keys()),
            value=default_size
        )
        image_input = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)
        message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
        temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
        top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
        max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
        chat_output = gr.Chatbot(label="Chatbot")
        submit_button = gr.Button("Submit")

        def update_components(model_key):
            model_info = MODELS[model_key]
            new_sizes = list(model_info["sizes"].keys())
            return [
                gr.Dropdown(choices=new_sizes, value=new_sizes[0]),
                gr.Image(visible=model_info["is_vision"])
            ]
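
        # Returning fresh component instances is the Gradio 4.x update idiom;
        # older Gradio 3.x code would use gr.Dropdown.update(...) / gr.update(...) instead.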

        model_key_selector.change(
            fn=update_components,
            inputs=model_key_selector,
            outputs=[model_size_selector, image_input]
        )

        submit_button.click(
            fn=generate_response,
            inputs=[
                message_input,
                image_input,
                chat_output,
                model_key_selector,
                model_size_selector,
                temperature_slider,
                top_p_slider,
                max_tokens_slider
            ],
            outputs=chat_output
        )

    return iface

create_interface().launch()
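
# The defaults suit Hugging Face Spaces. For a self-hosted deployment, the standard
# Gradio launch options would apply, e.g.:
#
# create_interface().launch(server_name="0.0.0.0", server_port=7860, share=False)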