"""Gradio chat demo that streams image-generation prompts from a local LoRA model.

The PEFT/LoRA checkpoint in ``MODEL_PATH`` is loaded once at import time and
served through a ``gr.ChatInterface``. Each reply is streamed token by token and
prefixed with ``OUTPUT_PREFIX``.
"""

from threading import Thread
from typing import Generator

import gradio as gr

# NOTE: keep `spaces` imported before any package that pulls in torch/CUDA
# (peft, transformers, torch) -- ZeroGPU Spaces require this ordering.
import spaces
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer
import torch

MODEL_PATH = "./lora_model"

# Text prepended to every streamed reply.
OUTPUT_PREFIX = "score_7_up,"

DEFAULT_SYSTEM_MESSAGE = (
    "You are a helpful image generation prompt writing AI. You write image "
    "generation prompts based on user requests. The prompt you write should be "
    "150 words or longer."
)

# Load the model and tokenizer once, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoPeftModelForCausalLM.from_pretrained(MODEL_PATH, device_map=0, torch_dtype="auto")


def _build_messages(message: str, history: list, system_message: str) -> list[dict[str, str]]:
    """Convert Gradio chat history into a chat-template message list.

    Accepts both the ``(user, assistant)`` pair format and the
    ``{"role": ..., "content": ...}`` dict format. Empty (falsy) turns are
    skipped. The list starts with the system message and ends with ``message``
    as the newest user turn.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            # Only plain-text user/assistant turns are meaningful to this text-only model.
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU()
@torch.no_grad()
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
) -> Generator[str, None, None]:
    """Stream a model reply to ``message`` given the prior chat ``history``.

    Args:
        message: The newest user message.
        history: Previous turns, either ``(user, assistant)`` pairs or
            ``{"role", "content"}`` dicts.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; ``0`` switches to greedy decoding.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated reply so far (``OUTPUT_PREFIX`` plus all streamed
        text), so Gradio re-renders the whole message on each step.

    Raises:
        TypeError: If the chat template does not render to a string.
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here once the stream has ended.
    """
    # Release cached-but-unused CUDA memory before starting a new generation.
    torch.cuda.empty_cache()

    messages = _build_messages(message, history, system_message)
    convo_string = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # An `assert` would be stripped under `python -O`; validate explicitly.
    if not isinstance(convo_string, str):
        raise TypeError(
            f"apply_chat_template returned {type(convo_string).__name__}, expected str"
        )

    # add_special_tokens=False: the rendered template is expected to already
    # contain any special tokens it needs.
    convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)
    # Shape (1, seq_len): a batch of one conversation, placed on the GPU.
    input_ids = torch.tensor([convo_tokens], dtype=torch.long, device="cuda")
    attention_mask = torch.ones_like(input_ids)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=int(max_tokens),  # slider values may arrive as floats
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
        temperature=temperature,
        top_k=None,
        top_p=top_p,
        streamer=streamer,
    )
    # Greedy decoding; only reachable if the temperature slider's minimum
    # (0.1 in the UI below) is lowered to 0.
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    worker_errors: list[BaseException] = []

    def _generate() -> None:
        """Run generation in the background, recording any failure for the consumer."""
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            # Send the stop signal so the consumer loop ends immediately
            # instead of waiting for the streamer timeout.
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    outputs = [OUTPUT_PREFIX]
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    worker.join()
    if worker_errors:
        raise worker_errors[0]


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=DEFAULT_SYSTEM_MESSAGE, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["Please write a random prompt."],
        ["I'd like an image based on the tags: black and white, two women, gym, minimalist design, exposed beams, kneeling, holding head, casual wear."],
        ["Can you create an image of a woman hiking and resting on a rock in a beautiful forest with mountains?"],
        ["can u make a creepy hallway pic, like something out of a weird dream, with shadows and a mysterious figure at the end? maybe some reds and blacks, make it look kinda eerie and otherworldly pls"],
        ["Beach sunset with silhouettes on rocks and birds flying"],
    ],
    cache_examples=False,
)


if __name__ == "__main__":
    demo.launch()