import streamlit as st
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
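

# Cache the pipeline so the model is not re-downloaded and re-initialized on
# every Streamlit rerun. This assumes Streamlit >= 1.18, where st.cache_resource
# is the recommended cache for unserializable objects such as models.
@st.cache_resource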
def load_model(model_name):
    """Load the tokenizer and model, and wrap them in a text-generation pipeline."""
    try:
        # Surface progress in the app while the (potentially slow) downloads run.
        st.write('Trying to load tokenizer')
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        st.write('Tokenizer loaded, trying to load model')
        model = AutoModelForCausalLM.from_pretrained(model_name)
        st.write('Model loaded, initializing pipeline')
        generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
        st.write('Pipeline ready')
        return generator
    except Exception as e:
        st.error(f"Failed to load model {model_name}: {str(e)}")
        return None
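

# distilgpt2 is a small distilled GPT-2 checkpoint (~82M parameters), so it is
# quick to download and can run on CPU; any causal LM on the Hugging Face Hub
# could be substituted here.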
model_name = "distilgpt2"
generator = load_model(model_name)
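
# Build the interface only if the model loaded successfully.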
if generator:
    user_prompt = st.text_area("Enter your prompt here:")

    if st.button('Generate'):
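        # Only call the pipeline once the user has actually entered a prompt.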
        if user_prompt:
            try:
                # max_new_tokens bounds only the generated continuation
                # (max_length would also count the prompt's own tokens).
                response = generator(user_prompt, max_new_tokens=50, clean_up_tokenization_spaces=True)
                st.text_area("Response:", value=response[0]['generated_text'], height=250, disabled=True)
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
        else:
            st.warning("Please enter a prompt.")
else:
    st.error("Model could not be loaded. Please ensure the model name is correct and try again.")
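
# Launch with (assuming this file is saved as app.py):
#   streamlit run app.py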