File size: 4,025 Bytes
ab382f0
 
 
 
 
 
 
 
 
 
 
 
c34a36c
 
ab382f0
 
 
 
f2fdba7
ab382f0
 
 
 
 
2ae46d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ec3e64
f2fdba7
 
2ae46d7
 
ab382f0
 
 
 
 
 
 
 
 
 
 
 
 
f0831ab
2ae46d7
 
 
 
 
 
 
 
 
 
 
 
 
 
f2fdba7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import os
import subprocess
import sys
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# Install flash-attn at startup without compiling its CUDA kernels
# (FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE). The current environment is merged in:
# passing only the one variable would wipe PATH, HOME, etc. for pip.
# pip is invoked through the running interpreter so the package lands in the
# same environment this app imports from. Best-effort: no `check`, so a failed
# install does not abort here.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Hugging Face repo of the 4-bit quantized checkpoint that gets loaded.
MODEL_ID = "nikravan/Marco_o1_q4"

# Prompt format used by predict(): one of 'Auto', 'ChatML', 'Mistral Instruct'.
CHAT_TEMPLATE = "ChatML"
MODEL_NAME = MODEL_ID.split("/")[-1]
# Maximum number of prompt tokens kept; predict() trims older tokens beyond this.
CONTEXT_LENGTH = 16000

# Dark UI accent. Gradio's `primary_hue` string shortcut only accepts its
# built-in palette names (slate, gray, zinc, neutral, stone, red, ..., rose);
# "black" is not one of them and makes gr.themes.Soft raise ValueError.
COLOR = "neutral"
EMOJI = "🤖"
DESCRIPTION = f"This is the {MODEL_NAME} model designed for testing thinking for general AI tasks."


def _build_prompt(message, history, system_prompt):
    """Render the conversation into one prompt string for ``CHAT_TEMPLATE``.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs (the unpacking
            below assumes Gradio's tuple-style history).
        system_prompt: Text placed in the template's system slot.

    Returns:
        ``(instruction, stop_tokens)``: the prompt string and the strings that
        end the reply when they appear in the streamed text.

    Raises:
        Exception: if ``CHAT_TEMPLATE`` is not a supported template name.
    """
    if CHAT_TEMPLATE == "Auto":
        # Stop strings are compared against decoded text, so use the EOS
        # *string*; the integer token id could never equal a text chunk.
        stop_tokens = [tokenizer.eos_token]
        instruction = system_prompt + "\n\n"
        for user, assistant in history:
            instruction += f"User: {user}\nAssistant: {assistant}\n"
        instruction += f"User: {message}\nAssistant:"
    elif CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
        for user, assistant in history:
            instruction += f'<|im_start|>user\n{user}\n<|im_end|>\n<|im_start|>assistant\n{assistant}\n<|im_end|>\n'
        instruction += f'<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n'
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = f'<s>[INST] {system_prompt}\n'
        for user, assistant in history:
            instruction += f'{user} [/INST] {assistant}</s>[INST]'
        instruction += f' {message} [/INST]'
    else:
        raise Exception("Incorrect chat template, select 'Auto', 'ChatML' or 'Mistral Instruct'")
    return instruction, stop_tokens


@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Stream a model reply to ``message`` for ``gr.ChatInterface``.

    Args:
        message: The new user message.
        history: Prior ``(user, assistant)`` turns supplied by Gradio.
        system_prompt: System prompt text.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        max_new_tokens: Upper bound on generated tokens (cast to ``int``).
        top_k: Top-k sampling cutoff (cast to ``int``).
        repetition_penalty: Repetition penalty; must be ``> 0``.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply generated so far, wrapped as ``"### $$ ... $$"``.

    Raises:
        gr.Error: if ``repetition_penalty`` is not positive.
    """
    # Validate before starting the worker thread: if model.generate raises
    # inside the thread, the streamer is never closed and the loop below blocks.
    if repetition_penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0.")

    instruction, stop_tokens = _build_prompt(message, history, system_prompt)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # No tokenizer-side truncation: it cuts from the right (dropping the newest
    # turn). The explicit left-trim below keeps the most recent context instead.
    enc = tokenizer(instruction, return_tensors="pt")
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        # Slider values can arrive as floats; generate expects integer counts.
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=repetition_penalty,
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
        )
    else:
        # transformers rejects temperature == 0 when sampling; treat it as greedy.
        generate_kwargs["do_sample"] = False

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    pieces = []
    for chunk in streamer:
        pieces.append(chunk)
        text = "".join(pieces)
        # A stop string may sit inside a larger decoded chunk, so search the
        # accumulated text. Special tokens are already stripped by the streamer
        # (skip_special_tokens=True), so this only catches stops that survive decoding.
        hits = [text.find(stop) for stop in stop_tokens if stop and stop in text]
        if hits:
            # NOTE(review): `thread` keeps generating after this break until
            # model.generate reaches its own end condition.
            yield f"### $$ {text[:min(hits)]} $$"
            break
        # Output is a Markdown heading wrapping LaTeX display math.
        # NOTE(review): display math collapses the whitespace/newlines of free
        # text; confirm this rendering is intended.
        yield f"### $$ {text} $$"


# Device that predict() moves its input tensors to.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 4-bit bitsandbytes quantization with bfloat16 compute.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
# The tokenizer is taken from the upstream Marco-o1 repo, while the weights
# come from the quantized MODEL_ID checkpoint.
tokenizer = AutoTokenizer.from_pretrained('AIDC-AI/Marco-o1')
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=quantization_config,
        attn_implementation="flash_attention_2",
    )
except (ImportError, ValueError) as err:
    # flash-attn is only installed best-effort at startup, and transformers
    # raises ImportError when it is missing (ValueError for unsupported GPUs).
    # Fall back to PyTorch's scaled-dot-product attention so the app still runs.
    print(f"flash_attention_2 unavailable ({err}); falling back to sdpa attention")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=quantization_config,
        attn_implementation="sdpa",
    )

# Build and launch the chat UI. The values of `additional_inputs` are passed to
# predict() after (message, history), in the order listed here.
gr.ChatInterface(
    predict,
    title=EMOJI + " " + MODEL_NAME,
    description=DESCRIPTION,
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox("You are a code assistant.", label="System prompt"),
        gr.Slider(0, 1, 0.3, label="Temperature"),
        # step=1 on the integer-valued sliders: without an explicit step Gradio
        # derives one from the range (0.1 for 1-80), which can send a
        # non-integer top_k that transformers rejects.
        gr.Slider(128, 4096, 1024, step=1, label="Max new tokens"),
        gr.Slider(1, 80, 40, step=1, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue=COLOR),
).queue().launch()