File size: 2,605 Bytes
86e2a34 12476bc 437cdee 86e2a34 12476bc 86e2a34 c1ef9c2 5362bf1 86e2a34 5362bf1 0fdb2f3 5362bf1 c1ef9c2 5362bf1 0fdb2f3 5362bf1 0fdb2f3 5362bf1 0fdb2f3 5362bf1 86e2a34 5362bf1 54f41e4 5362bf1 54f41e4 0fdb2f3 12476bc 0fdb2f3 5362bf1 86e2a34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import os

import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers.utils import logging
# Logger instance for the transformers library (via transformers.utils.logging).
logger = logging.get_logger("transformers")

# Load the model and tokenizer.
# The checkpoint can be overridden with the TELLMYSTORY_MODEL environment variable.
# The default is GPT-2 because it loads with a plain transformers install. The former
# default, "TheBloke/Llama-2-7B-Chat-GGML", holds GGML weights intended for llama.cpp,
# which the transformers Auto* classes cannot load, so the app failed at startup.
# Other candidates previously listed here (GPTQ-quantized; these typically require
# extra quantization packages to be installed):
#   "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ", "TheBloke/zephyr-7B-beta-GPTQ"
model_name = os.environ.get("TELLMYSTORY_MODEL", "openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# For an encoder-decoder checkpoint, load it with AutoModelForSeq2SeqLM instead:
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
def generate_text(input_text, max_new_tokens=100):
    """Continue ``input_text`` with the loaded causal language model.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        input_text: The story idea / prompt entered by the user.
        max_new_tokens: Maximum number of tokens to generate after the prompt.

    Returns:
        The decoded model output (the prompt followed by the generated
        continuation) with special tokens removed, or an empty string when
        ``input_text`` is empty, None, or only whitespace.
    """
    # A blank prompt would give generate() a zero-length input and fail;
    # there is nothing to continue, so return early.
    if not input_text or not input_text.strip():
        return ""

    # Calling the tokenizer (instead of .encode) also returns the attention
    # mask, so generate() does not have to guess which tokens are padding.
    inputs = tokenizer(input_text, return_tensors="pt")

    # Some tokenizers (GPT-2, for example) define no pad token; fall back to EOS.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # max_new_tokens counts generated tokens only; the previous
        # max_length=100 also counted the prompt, so long prompts got nothing.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # top_k / top_p / temperature only apply when sampling is enabled;
        # without do_sample=True generate() decodes greedily and ignores them.
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=pad_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# def generate_text(prompt):
# inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
# summary_ids = model.generate(inputs["input_ids"], max_new_tokens=512, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
# return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# #for training the model after the data is collected
# #model.save_pretrained("model")
# #tokenizer.save_pretrained("model")
# #for the app functions
# def show_output_text(message):
# history.append((message,""))
# story = generate_text(message)
# history[-1] = (message,story)
# return story
# def clear_textbox():
# return None,None
# Build a text-in / text-out Gradio interface around generate_text.
interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="TeLLMyStory",
    description="Enter your story idea and the model will generate the story based on it.",
)
# with gr.Blocks() as demo:
# gr.Markdown("TeLLMyStory chatbot")
# with gr.Row():
# input_text = gr.Textbox(label="Enter your story idea here", placeholder="Once upon a time...")
# clear_button = gr.Button("Clear",variant="secondary")
# submit_button = gr.Button("Submit", variant="primary")
# with gr.Row():
# gr.Markdown("And see the story take shape here")
# output_text = gr.Textbox(label="History")
# submit_button.click(fn=show_output_text, inputs=input_text,outputs=output_text)
# clear_button.click(fn=clear_textbox,outputs=[input_text,output_text])
# Launch the interface.
# NOTE(review): consider wrapping this in `if __name__ == "__main__":` so that
# importing this module does not launch the app as a side effect.
interface.launch()
|