Spestly committed
Commit 2960eb4 · verified · 1 Parent(s): bd64553

Update app.py

Files changed (1)
  1. app.py +45 -162
app.py CHANGED
@@ -1,173 +1,56 @@
- import gc
- import torch
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from huggingface_hub import login
- import os

- # Load Hugging Face token
  HF_TOKEN = os.getenv("HF_TOKEN")
  login(token=HF_TOKEN)

- # Define models
- MODELS = {
-     "atlas-flash-1215": {
-         "name": "🦁 Atlas-Flash 1215",
-         "sizes": {
-             "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
-         },
-         "emoji": "🦁",
-         "experimental": True,
-         "is_vision": False,
-         "system_prompt_env": "ATLAS_FLASH_1215",
-     },
-     "atlas-pro-0403": {
-         "name": "🏆 Atlas-Pro 0403",
-         "sizes": {
-             "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
-         },
-         "emoji": "🏆",
-         "experimental": True,
-         "is_vision": False,
-         "system_prompt_env": "ATLAS_PRO_0403",
-     },
- }
-
- # Clear memory
- def clear_memory():
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-     gc.collect()
-
- # Load model
- def load_model(model_key, model_size):
-     try:
-         clear_memory()
-
-         # Unload previous model if any
-         global current_model
-         if current_model is not None:
-             del current_model["model"]
-             del current_model["tokenizer"]
-             clear_memory()
-
-         model_path = MODELS[model_key]["sizes"][model_size]
-
-         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-         model = AutoModelForCausalLM.from_pretrained(
-             model_path,
-             device_map="auto",
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-             trust_remote_code=True,
-             low_cpu_mem_usage=True
-         )
-
-         current_model.update({
-             "tokenizer": tokenizer,
-             "model": model,
-             "config": {
-                 "name": f"{MODELS[model_key]['name']} {model_size}",
-                 "path": model_path,
-                 "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
-             }
-         })
-         return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
-     except Exception as e:
-         return f"❌ Error: {str(e)}"
-
- # Respond to input
- def respond(prompt, max_tokens, temperature, top_p, top_k):
-     if not current_model["model"] or not current_model["tokenizer"]:
-         return "⚠️ Please select and load a model first"
-
-     try:
-         system_prompt = current_model["config"]["system_prompt"]
-         if not system_prompt:
-             return "⚠️ System prompt not found for the selected model."
-
-         full_prompt = f"{system_prompt}\n\n### Instruction:\n{prompt}\n\n### Response:"
-
-         inputs = current_model["tokenizer"](
-             full_prompt,
-             return_tensors="pt",
-             max_length=512,
-             truncation=True,
-             padding=True
-         )
-         with torch.no_grad():
-             output = current_model["model"].generate(
-                 input_ids=inputs.input_ids,
-                 attention_mask=inputs.attention_mask,
-                 max_new_tokens=max_tokens,
-                 temperature=temperature,
-                 top_p=top_p,
-                 top_k=top_k,
-                 do_sample=True,
-                 pad_token_id=current_model["tokenizer"].pad_token_id,
-                 eos_token_id=current_model["tokenizer"].eos_token_id,
-             )
-         response = current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
-
-         if full_prompt in response:
-             response = response.replace(full_prompt, "").strip()
-
-         return response
-     except Exception as e:
-         return f"⚠️ Generation Error: {str(e)}"
-     finally:
-         clear_memory()
-
- # Initialize model storage
- current_model = {"tokenizer": None, "model": None, "config": None}
-
- # UI for Gradio
- def gradio_ui():
-     def load_and_set_model(model_key, model_size):
-         return load_model(model_key, model_size)
-
-     with gr.Blocks() as app:
-         gr.Markdown("## 🦁 Atlas Inference Platform - Experimental 🧪")
-
-         with gr.Row():
-             model_key_dropdown = gr.Dropdown(
-                 choices=list(MODELS.keys()),
-                 value=list(MODELS.keys())[0],
-                 label="Select Model Variant",
-                 interactive=True
-             )
-             model_size_dropdown = gr.Dropdown(
-                 choices=list(MODELS[list(MODELS.keys())[0]]["sizes"].keys()),
-                 value="1.5B",
-                 label="Select Model Size",
-                 interactive=True
-             )
-             load_button = gr.Button("Load Model")
-
-         load_status = gr.Textbox(label="Model Load Status", interactive=False)
-
-         load_button.click(
-             load_and_set_model,
-             inputs=[model_key_dropdown, model_size_dropdown],
-             outputs=load_status,
-         )
-
-         with gr.Row():
-             prompt_input = gr.Textbox(label="Input Prompt", lines=4)
-             max_tokens_slider = gr.Slider(10, 512, value=256, step=10, label="Max Tokens")
-             temperature_slider = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
-             top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-P")
-             top_k_slider = gr.Slider(1, 100, value=50, step=1, label="Top-K")
-
-         generate_button = gr.Button("Generate Response")
-         response_output = gr.Textbox(label="Model Response", lines=6, interactive=False)
-
-         generate_button.click(
-             respond,
-             inputs=[prompt_input, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider],
-             outputs=response_output,
-         )
-
-     return app
-
- if __name__ == "__main__":
-     gradio_ui().launch()
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from huggingface_hub import login
+ import torch
+ import os

  HF_TOKEN = os.getenv("HF_TOKEN")
  login(token=HF_TOKEN)
+
+ # Load the Atlas-Pro checkpoint once at startup (CPU-friendly settings)
+ model_name = "Spestly/Atlas-Pro-1.5B-Preview"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, low_cpu_mem_usage=True)
+
+ model.eval()
+
+ def generate_response(message, history):
+     # Wrap the user message in the Alpaca-style instruction template
+     instruction = (
+         "You are an LLM called Atlas. You are finetuned by Aayan Mishra. You are NOT trained by Anthropic. "
+         "You are a Qwen 2.5 fine-tune. Your purpose is to help the user accomplish their request to the best of your abilities. "
+         "Below is an instruction that describes a task. Answer it clearly and concisely.\n\n"
+         f"### Instruction:\n{message}\n\n### Response:"
+     )
+
+     inputs = tokenizer(instruction, return_tensors="pt")
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=1000,
+             num_return_sequences=1,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True
+         )
+
+     # Keep only the text after the final "### Response:" marker
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     response = response.split("### Response:")[-1].strip()
+
+     return response
+
+ iface = gr.ChatInterface(
+     generate_response,
+     chatbot=gr.Chatbot(height=600, type="messages"),
+     textbox=gr.Textbox(placeholder="Type your message here...", container=False, scale=7),
+     title="🦁 Atlas-Pro",
+     description="Chat with Atlas-Pro",
+     theme="citrus",
+     examples=[
+         "Can you give me a good salsa recipe?",
+         "Write an engaging two-line horror story.",
+         "What is the capital of Australia?",
+     ],
+     type="messages"
+ )
+
+ iface.launch()
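
For reference, a `gr.ChatInterface` app like the one above can also be exercised programmatically once the Space is running. Below is a minimal sketch using `gradio_client`; the Space id "Spestly/Atlas-Pro" is an assumption for illustration (this commit does not name the Space), while "/chat" is the endpoint ChatInterface registers by default.

# Minimal sketch (not part of this commit): query the running Space
# with gradio_client. Replace the Space id below with the real
# owner/space-name; it is a hypothetical placeholder here.
from gradio_client import Client

client = Client("Spestly/Atlas-Pro")  # hypothetical Space id
result = client.predict(
    message="What is the capital of Australia?",
    api_name="/chat",  # default endpoint name for gr.ChatInterface
)
print(result)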