from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
import torch
import optimum    # backend dependency for loading GPTQ-quantized checkpoints
import auto_gptq  # backend dependency for loading GPTQ-quantized checkpoints
import gradio as gr
import contextlib
import time

# Run on the first GPU when available, otherwise fall back to CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model_name = "TheBloke/zephyr-7B-beta-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, padding_side="left")

# 4-bit GPTQ weights with group size 128; disable_exllama=False keeps the ExLlama kernels enabled
quantization_config_loading = GPTQConfig(
    bits=4,
    group_size=128,
    disable_exllama=False,
)

# device_map="auto" already dispatches the quantized weights to the available device,
# so no additional model.to(device) call is needed (moving a dispatched quantized
# model can raise an error)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config_loading,
    device_map="auto",
)
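# Note (an assumption, not exercised by this script): the ExLlama kernels expect a
# CUDA device. On a CPU-only machine, a sketch of a fallback would be to load with
# the kernels disabled:
#
#   quantization_config_loading = GPTQConfig(bits=4, group_size=128, disable_exllama=True)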
def generate_text(input_text, max_new_tokens=512, top_k=50, top_p=0.95, temperature=0.7, no_grad=False):
    # The tokenizer has no dedicated pad token, so reuse the EOS token for padding
    tokenizer.pad_token_id = tokenizer.eos_token_id
    input_ids = tokenizer.encode(input_text, padding=True, return_tensors="pt").to(device)
    attention_mask = input_ids.ne(tokenizer.pad_token_id).long().to(device)

    # Optionally disable gradient tracking; the sampling call itself is identical either way
    context = torch.no_grad() if no_grad else contextlib.nullcontext()
    with context:
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            do_sample=True,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
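# A minimal direct-use sketch (not wired into the Gradio UI): calling generate_text
# on its own, e.g. for a quick smoke test. Prompt and parameter values are illustrative.
#
#   story = generate_text(
#       "Once upon a time a dragon learned to bake bread.",
#       max_new_tokens=200,
#       temperature=0.8,
#       no_grad=True,
#   )
#   print(story)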
# Wall-clock time of the last generation, updated by generate_response
time_story = 0
def generate_response(message, history: list[dict], max_tokens, temperature, top_p):
    global time_story

    # With type="messages", history arrives as {"role": ..., "content": ...} dicts;
    # rebuild the conversation and append the new user prompt
    messages = [{"role": val["role"], "content": val["content"]} for val in history if val]
    messages.append({"role": "user", "content": message})

    # Only the latest prompt is forwarded to the model; the assembled messages list
    # is kept for potential chat-template formatting (see the sketch below)
    start = time.time()
    output = generate_text(message, max_new_tokens=max_tokens, top_p=top_p, temperature=temperature)
    end = time.time()
    time_story = end - start
    print(f"Time to generate the story: {time_story}")

    # gr.ChatInterface manages the displayed history itself, so no manual append is needed
    yield output
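# Sketch of an alternative (an assumption, not the original flow): if the full
# conversation should reach the model, the assembled messages could be rendered
# with the tokenizer's chat template and passed to generate_text instead of the
# raw prompt. build_chat_prompt is a hypothetical helper, defined but not called.
def build_chat_prompt(messages: list[dict]) -> str:
    # apply_chat_template renders the role/content dicts into the model's chat
    # format and appends the assistant prefix so generation continues the dialogue
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )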
title = "TeLLMyStory"
description = "An LLM for story generation, aimed at reinforcing controllability"
theme = gr.Theme.from_hub("Yntec/HaleyCH_Theme_Yellow_Blue")
examples = [
    ["Once upon a time a witch named Malefique was against the wedding of her daughter with the son of the king of the nearby kingdom."],
    ["Once upon a time an ice-cream met a spoon and they fell in love"],
    ["The neverending day began with a beautiful sunshine and an AI robot which was seeking humans on the desert Earth."],
]
demo = gr.ChatInterface(
    generate_response,
    type="messages",
    title=title,
    description=description,
    theme=theme,
    examples=examples,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=100, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    stop_btn="Stop",
    delete_cache=(60, 60),  # every 60 s, delete cached files older than 60 s
    show_progress="full",
    save_history=True,
)
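# Optional (a sketch, not part of the original setup): cap concurrent requests
# before launching, since each generation occupies the single loaded model.
#
#   demo.queue(max_size=8)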
if __name__ == "__main__":
    # share=True exposes a temporary public Gradio link; debug=True surfaces errors in the console
    demo.launch(share=True, debug=True)