# llm-t97/chatt5-small.py: interactive chat loop for a fine-tuned T5-small model

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the fine-tuned model and tokenizer
model_path = "./t5-small-finetuned"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)

# Put the model in evaluation mode (disables dropout and similar training-only behavior)
model.eval()
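
# Optional: run on GPU when available. This is a sketch added here, not part of
# the original script; if you enable it, the encoded input_ids below must be
# moved to the same device before calling generate().
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)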

def chat_with_model(prompt):
    # Encode the input text as token IDs
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Generate a response without tracking gradients
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=150,       # Maximum length of the generated reply; adjust as needed
            num_beams=5,          # Beam search usually gives more coherent results than greedy decoding
            early_stopping=True,
        )

    # Decode the generated token IDs back into text
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response
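
# Alternative decoding: the sketch below swaps beam search for nucleus
# sampling, which tends to produce more varied replies. It is not part of the
# original script, and the top_p / temperature values are illustrative
# assumptions rather than tuned settings.
def chat_with_model_sampling(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=150,
            do_sample=True,    # Sample from the output distribution instead of beam search
            top_p=0.9,         # Nucleus sampling: keep the smallest token set covering 90% probability mass
            temperature=0.8,   # Slightly soften the distribution
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)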

def main():
    print("Chatbot is running. Type 'exit' to end the conversation.")
    while True:
        user_input = input("You: ")
        if user_input.lower() == 'exit':
            break
        response = chat_with_model(user_input)
        print(f"Bot: {response}")

if __name__ == "__main__":
    main()
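
# Usage (assumes ./t5-small-finetuned contains a tokenizer and model saved
# with save_pretrained()):
#     python chatt5-small.py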