File size: 3,522 Bytes
03071d5
 
 
 
 
437cdee
03071d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1947635
03071d5
 
 
 
 
 
 
 
 
 
 
4e9028a
03071d5
 
 
 
4e9028a
03071d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12476bc
03071d5
 
 
 
 
2616382
3f1d57b
03071d5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Latest version of app.py
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
import torch
import optimum
import auto_gptq
import gradio as gr
import time

# Device that tokenized prompts are moved to before generation: the first
# GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# GPTQ-quantized Zephyr-7B-beta checkpoint on the Hugging Face Hub.
model_name = "TheBloke/zephyr-7B-beta-GPTQ"

# Left padding keeps prompt tokens adjacent to the generated tokens, which is
# what decoder-only generation expects.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, padding_side="left")
quantization_config_loading = GPTQConfig(
    bits=4,
    group_size=128,
    # Keep the exllama backend enabled.
    # NOTE(review): newer transformers releases spell this `use_exllama=True`;
    # confirm against the installed version before switching.
    disable_exllama=False,
)
# device_map="auto" lets accelerate place (and, if memory is short, shard or
# offload) the weights itself. Do NOT follow this with `model.to(device)`:
# moving an accelerate-dispatched model is unsupported and raises if any module
# was offloaded to CPU/disk.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config_loading,
    device_map="auto",
)

def generate_text(input_text, max_new_tokens=512, top_k=50, top_p=0.95, temperature=0.7, no_grad=False):
    """Sample a continuation of ``input_text`` from the module-level model.

    Args:
        input_text: Prompt string to continue.
        max_new_tokens: Maximum number of tokens to generate beyond the prompt.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature (sampling is always on here).
        no_grad: If True, wrap ``model.generate`` in an explicit
            ``torch.no_grad()`` block. Kept for backward compatibility.

    Returns:
        The first generated sequence decoded with special tokens stripped.
        The whole sequence is decoded, so for a causal LM this is presumably
        the prompt followed by the continuation.
    """
    # Use EOS as the pad id. Note this mutates the shared module-level tokenizer.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    # Take the attention mask from the tokenizer. Deriving it from
    # `input_ids != pad_token_id` would wrongly mask any genuine EOS token in
    # the prompt, because pad and EOS share the same id.
    encoded = tokenizer(input_text, padding=True, return_tensors="pt").to(device)
    gen_kwargs = dict(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        do_sample=True,
        # Explicit pad id, so generate does not have to fall back to (and warn about) it.
        pad_token_id=tokenizer.pad_token_id,
    )
    if no_grad:
        with torch.no_grad():
            output = model.generate(**gen_kwargs)
    else:
        output = model.generate(**gen_kwargs)
    return tokenizer.decode(output[0], skip_special_tokens=True)


# Wall-clock duration (seconds) of the most recent story generation; updated by
# generate_response and only used for logging.
time_story = 0


def generate_response(input, history: list[dict], max_tokens, temperature, top_p):
    """Gradio chat callback: generate a story from the user's opening text.

    Args:
        input: The user's prompt. The name shadows the builtin but is kept so
            the callback signature stays unchanged.
        history: Previous turns in Gradio "messages" format (role/content
            dicts). Not used for generation: each story is produced from the
            prompt alone. It is deliberately not mutated here. Appending a
            (prompt, reply) tuple would corrupt the messages-format history
            that the ChatInterface maintains itself.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The text returned by ``generate_text`` for this prompt.
    """
    # Without this declaration, the assignment below would only create a local
    # and the module-level value would never change.
    global time_story

    start = time.perf_counter()
    output = generate_text(input, max_new_tokens=max_tokens, top_p=top_p, temperature=temperature)
    time_story = time.perf_counter() - start
    print(f'Time to generate the story: {time_story}')
    yield output

# Define the chat interface.
title = "TeLLMyStory"
description = "A LLM for stories generation aiming the reinforcement of the controllability aspect"
theme = gr.Theme.from_hub("Yntec/HaleyCH_Theme_Yellow_Blue")

# Example story openings; each row holds only the opening message.
examples = [
    ["Once upon a time a witch named Malefique was against the wedding of her daughter with the son of the king of the nearby kingdom."],
    ["Once upon a time an ice-cream met a spoon and they fell in love"],
    ["The neverending day began with a beautiful sunshine and an AI robot which was seeking humans on the desert Earth."],
]

# Chat UI wired to generate_response. The three sliders are passed to the
# callback after the message and history, in order: max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    generate_response,
    type="messages",
    title=title,
    description=description,
    theme=theme,
    examples=examples,
    stop_btn="Stop",
    show_progress="full",
    save_history=True,
    # NOTE(review): Gradio documents this as (frequency, age) in seconds; confirm.
    delete_cache=[60, 60],
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=100, step=1),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.7, step=0.1),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, value=0.95, step=0.05),
    ],
)


if __name__ == "__main__":
    # Serve the app when run as a script: share=True requests a public Gradio
    # link and debug=True turns on Gradio's debug mode.
    demo.launch(debug=True, share=True)