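"""TeLLMyStory: a small Gradio app that takes a story idea as input and
generates a short story with a Hugging Face causal language model."""
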
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers.utils import logging
import gradio as gr

# Define the logger instance for the transformers library
logger = logging.get_logger("transformers")
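
# Optional: raise the log level to see checkpoint download/loading details,
# e.g. logging.set_verbosity_info()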

# Load the model and tokenizer.
# GGML checkpoints (e.g. "TheBloke/Llama-2-7B-Chat-GGML") are not loadable
# with AutoModelForCausalLM, and GPTQ checkpoints need extra backends, so a
# plain transformers checkpoint is used here. Other options considered:
# "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ", "TheBloke/Llama-2-7B-Chat-GGML",
# "TheBloke/zephyr-7B-beta-GPTQ"
model_name = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)  # for seq2seq checkpoints
model = AutoModelForCausalLM.from_pretrained(model_name)

# Generate text from the user's prompt with the causal LM
def generate_text(input_text):
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    # do_sample=True is required for top_k/top_p/temperature to take effect
    # (otherwise generate() decodes greedily and ignores them); pad_token_id
    # silences the "no pad token" warning on GPT-2-style models.
    output = model.generate(input_ids, max_length=100, num_return_sequences=1,
                            no_repeat_ngram_size=2, do_sample=True, top_k=50,
                            top_p=0.95, temperature=0.7,
                            pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(output[0], skip_special_tokens=True)
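
# Example (sampled output varies from run to run):
#   generate_text("Once upon a time")
#   -> the prompt plus a continuation, up to 100 tokens in total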

# def generate_text(prompt):
#     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
#     summary_ids = model.generate(inputs["input_ids"], max_new_tokens=512, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
#     return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# #for training the model after the data is collected
# #model.save_pretrained("model")
# #tokenizer.save_pretrained("model")

# #for the app functions

# history = []  # shared chat history; needed by show_output_text below

# def show_output_text(message):
#     history.append((message, ""))
#     story = generate_text(message)
#     history[-1] = (message, story)
#     return story

# def clear_textbox():
#     return None, None

# Create the text-in/text-out interface with Gradio
interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="TeLLMyStory",
    description="Enter your story idea and the model will generate the story based on it.",
)
# with gr.Blocks() as demo:
#     gr.Markdown("TeLLMyStory chatbot")
#     with gr.Row():
#         input_text = gr.Textbox(label="Enter your story idea here", placeholder="Once upon a time...")
#         clear_button = gr.Button("Clear",variant="secondary")
#         submit_button = gr.Button("Submit", variant="primary")

#     with gr.Row():
#         gr.Markdown("And see the story take shape here")
#         output_text = gr.Textbox(label="History")
    
#     submit_button.click(fn=show_output_text, inputs=input_text,outputs=output_text)
#     clear_button.click(fn=clear_textbox,outputs=[input_text,output_text])
# Launch the interface
interface.launch()
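
# Pass share=True to launch() to expose a temporary public URL, e.g.:
# interface.launch(share=True)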