# Last version of app.py
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
import torch
import optimum    # not used directly; imported so the app fails fast if the GPTQ backend is missing
import auto_gptq  # not used directly; imported so the app fails fast if the GPTQ backend is missing
import gradio as gr
import time
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model_name = "TheBloke/zephyr-7B-beta-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True,padding_side="left")
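# Note: padding_side="left" keeps prompts right-aligned against the generated tokens,
# which is what a decoder-only model such as Zephyr expects when inputs are padded.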
quantization_config_loading = GPTQConfig(
    bits=4,
    group_size=128,
    disable_exllama=False,  # keep the ExLlama kernels enabled for 4-bit inference
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config_loading,
    device_map="auto",  # Accelerate already places the quantized weights; an extra .to(device) is unnecessary and can raise an error
)
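# Optional sanity check (assumes the model was dispatched with device_map="auto"):
# print(model.hf_device_map)  # shows which device each layer group landed on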
def generate_text(input_text, max_new_tokens=512, top_k=50, top_p=0.95, temperature=0.7, no_grad=False):
    tokenizer.pad_token_id = tokenizer.eos_token_id
    input_ids = tokenizer.encode(input_text, padding=True, return_tensors="pt").to(device)
    attention_mask = input_ids.ne(tokenizer.pad_token_id).long().to(device)
    if no_grad:
        with torch.no_grad():
            output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
                                    top_k=top_k, top_p=top_p, temperature=temperature, do_sample=True)
    else:
        output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
                                top_k=top_k, top_p=top_p, temperature=temperature, do_sample=True)
    return tokenizer.decode(output[0], skip_special_tokens=True)
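# Quick smoke test of the raw generator, outside the Gradio app (hypothetical prompt):
# print(generate_text("Once upon a time", max_new_tokens=32))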
def generate_response(message, history: list[dict], max_tokens, temperature, top_p):
    # With type="messages", history entries are dicts of the form {"role": ..., "content": ...}.
    # The running conversation is rebuilt here for reference; only the latest prompt is sent to the model.
    messages = [{"role": val["role"], "content": val["content"]} for val in history if val]
    messages.append({"role": "user", "content": message})
    start = time.time()
    output = generate_text(message, max_new_tokens=max_tokens, top_p=top_p, temperature=temperature)
    end = time.time()
    print(f"Time to generate the story: {end - start:.2f}s")
    # gr.ChatInterface maintains the chat history itself, so yielding the text is enough.
    yield output
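# Minimal check without the UI (generate_response is a generator, so pull one value);
# the prompt and sampling values below are illustrative only:
# print(next(generate_response("Once upon a time an ice-cream met a spoon", [], 64, 0.7, 0.95)))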
# Define the chat interface
title = "TeLLMyStory"
description = "An LLM for story generation, aimed at reinforcing controllability"
theme = gr.Theme.from_hub("Yntec/HaleyCH_Theme_Yellow_Blue")
examples = [
    ["Once upon a time, a witch named Malefique was against the wedding of her daughter with the son of the king of the nearby kingdom."],
    ["Once upon a time, an ice-cream met a spoon and they fell in love."],
    ["The never-ending day began with beautiful sunshine and an AI robot that was seeking humans on the desert Earth."],
]
demo = gr.ChatInterface(
    generate_response,
    type="messages",
    title=title,
    description=description,
    theme=theme,
    examples=examples,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=100, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    stop_btn="Stop",
    delete_cache=(60, 60),
    show_progress="full",
    save_history=True,
)
if __name__ == "__main__":
    demo.launch(share=True, debug=True)