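# NOTE(review): the six lines below look like web-page residue rather than code;
# they are kept verbatim as comments so the module stays importable.
# fancyfeast
# doop
# 2e62a51
# raw
# history blame
# 2.81 kB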
import gradio as gr
import spaces
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from typing import Generator
# Both the tokenizer and the PEFT model are loaded from the same local directory.
_LORA_MODEL_DIR = "./lora_model"

# Loaded once at import time and shared by every request handled by this module.
tokenizer = AutoTokenizer.from_pretrained(_LORA_MODEL_DIR)
model = AutoPeftModelForCausalLM.from_pretrained(
    _LORA_MODEL_DIR,
    device_map=0,
    torch_dtype="auto",
)
def _build_chat_messages(message, history, system_message):
    """Assemble the chat-template message list for one request.

    The system message comes first, followed by every non-empty user and
    assistant entry of ``history`` in order, and finally the new ``message``
    as the last user turn.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        # Each turn is indexed as (user_text, assistant_text); empty or None
        # halves are skipped rather than sent as blank messages.
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU()
@torch.no_grad()
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
) -> Generator[str, None, None]:
    """Stream the model's reply to ``message`` as a growing string.

    Args:
        message: The new user message.
        history: Earlier turns, each indexable as (user_text, assistant_text).
        system_message: Content of the leading system message.
        max_tokens: Upper bound on newly generated tokens (coerced to ``int``).
        temperature: Sampling temperature; exactly ``0`` disables sampling.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Yields:
        The text generated so far; each yield extends the previous one.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread,
            re-raised here once the stream has been drained.
    """
    torch.cuda.empty_cache()

    messages = _build_chat_messages(message, history, system_message)
    convo_string = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # tokenize=False is requested above, so a string is expected; this narrows the type.
    assert isinstance(convo_string, str)

    # The chat template output is encoded as-is: no extra special tokens, no truncation.
    convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

    # Add a batch dimension of 1 and move to the GPU; every prompt token is attended to.
    input_ids = torch.tensor(convo_tokens, dtype=torch.long).unsqueeze(0).to("cuda")
    attention_mask = torch.ones_like(input_ids)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # UI sliders may hand over a float; generation needs an integer budget.
        max_new_tokens=int(max_tokens),
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
        temperature=temperature,
        top_k=None,
        top_p=top_p,
        streamer=streamer,
    )
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    worker_errors: list[BaseException] = []

    def _run_generation() -> None:
        # Runs on the worker thread. A failure here would otherwise be lost and
        # leave the consumer waiting for the streamer timeout, so the error is
        # recorded and the stream is closed explicitly.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    partial_text = ""
    for text in streamer:
        partial_text += text
        yield partial_text

    # The stream is exhausted, so generation has finished (or failed); reap the
    # thread and surface any failure on the caller's thread.
    worker.join()
    if worker_errors:
        raise worker_errors[0]
# Default text pre-filled in the "System message" textbox.
_DEFAULT_SYSTEM_MESSAGE = (
    "You are a helpful image generation prompt writing AI. "
    "You write image generation prompts based on user requests. "
    "The prompt you write should be 150 words or longer."
)

# Extra controls, listed in the same order as the trailing parameters of
# `respond`: system_message, max_tokens, temperature, top_p.
_system_message_box = gr.Textbox(value=_DEFAULT_SYSTEM_MESSAGE, label="System message")
_max_tokens_slider = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
_temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature")
_top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")

# Chat UI wired to `respond`, with the controls above as additional inputs.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()