File size: 3,537 Bytes
2616382
12476bc
437cdee
1e59ffd
437cdee
86e2a34
12476bc
 
86e2a34
2616382
 
 
 
5362bf1
72903e4
8d9e0dc
f694567
 
 
 
 
 
 
72903e4
0fdb2f3
1e59ffd
0fdb2f3
f694567
5a33e76
f694567
4faf856
f694567
 
2616382
 
 
 
5362bf1
2d4b9ba
72903e4
 
2d4b9ba
c1ef9c2
 
 
 
 
5362bf1
 
 
 
 
 
0fdb2f3
5362bf1
 
 
0fdb2f3
5362bf1
0fdb2f3
 
5362bf1
 
2616382
5362bf1
 
 
54f41e4
 
 
5362bf1
 
54f41e4
0fdb2f3
12476bc
0fdb2f3
 
5362bf1
2616382
3f1d57b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging
import gradio as gr
#import spaces

# Logger for the transformers library (not referenced elsewhere in the visible code).
logger = logging.get_logger("transformers")

# Hugging Face Hub checkpoint served by this app. Alternatives that were tried:
#   "openai-community/gpt2"
#   "TheBloke/Llama-2-7B-Chat-GGML"
#       NOTE(review): GGML weights are presumably not loadable through
#       AutoModelForCausalLM; confirm before switching to it.
#   "TheBloke/zephyr-7B-beta-GPTQ"
model_name = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"

# Fast tokenizer for the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Model weights: device_map="auto" lets transformers/accelerate place the
# weights on whatever devices are available; trust_remote_code stays off so no
# repository-supplied Python code is executed; revision pins the "main" branch.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    trust_remote_code=False,
    revision="main",
)

# Generate text using the model and tokenizer
#@spaces.GPU(duration=60)
def generate_text(input_text):
    """Generate a story from the user's idea using the loaded causal LM.

    The prompt is tokenized (input ids plus attention mask), moved to the
    model's device, and continued with top-k / nucleus sampling.

    Args:
        input_text: The story idea typed into the Gradio textbox.

    Returns:
        The decoded output sequence (the prompt followed by the generated
        continuation) with special tokens such as BOS/EOS stripped, or an
        empty string when ``input_text`` is empty, ``None`` or whitespace-only.
    """
    # A blank textbox gives the model nothing to continue; for tokenizers that
    # add no BOS token an empty prompt would also produce an empty tensor that
    # generate() cannot handle.
    if not input_text or not input_text.strip():
        return ""

    # Calling the tokenizer (instead of .encode) also returns the attention
    # mask, and .to(model.device) keeps the inputs on the same device as the
    # weights placed by device_map="auto".
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=512,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        do_sample=True,
        # NOTE(review): the tokenizer presumably defines no pad token; reusing
        # EOS avoids generate() falling back to it with a warning.
        pad_token_id=tokenizer.eos_token_id,
    )

    # skip_special_tokens keeps markers like "<s>" / "</s>" out of the story
    # shown to the user.
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Web UI: one free-text input routed through generate_text into a text output.
interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="TeLLMyStory",
    description="Enter your story idea and the model will generate the story based on it.",
)

# Start the Gradio app.
interface.launch()


# Example of disabling Exllama backend (if applicable in your configuration)
#config = {"disable_exllama": True}
#model.config.update(config)

# def generate_text(prompt):
#     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
#     summary_ids = model.generate(inputs["input_ids"], max_new_tokens=512, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
#     return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# #for training the model after the data is collected
# #model.save_pretrained("model")
# #tokenizer.save_pretrained("model")

# #for the app functions

# def show_output_text(message):
#     history.append((message,""))
#     story = generate_text(message)
#     history[-1] = (message,story)
#     return story

# def clear_textbox():
#     return None,None

# # Create an input interface with Gradio

# with gr.Blocks() as demo:
#     gr.Markdown("TeLLMyStory chatbot")
#     with gr.Row():
#         input_text = gr.Textbox(label="Enter your story idea here", placeholder="Once upon a time...")
#         clear_button = gr.Button("Clear",variant="secondary")
#         submit_button = gr.Button("Submit", variant="primary")

#     with gr.Row():
#         gr.Markdown("And see the story take shape here")
#         output_text = gr.Textbox(label="History")
    
#     submit_button.click(fn=show_output_text, inputs=input_text,outputs=output_text)
#     clear_button.click(fn=clear_textbox,outputs=[input_text,output_text])
# # Launch the interface