Example script for using this model
Hello all!
I was testing several of these Switch models and decided to share some of the code I used. It keeps a simple chat history which, unless you know what you're doing with a Switch model, can really confuse things.
Anyway, I tested on CPU only. While running, this model used an additional 60GB of RAM, so I would make sure you run it on a computer with at least 96GB of RAM. Here is the script; if you have any questions, feel free to ask me:
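Since the model needs that much headroom, it can be worth checking physical memory before loading. This is a sketch of my own (not part of the original script) that works on Linux and macOS via os.sysconf; the 96GB threshold is just the figure mentioned above:

```python
import os

def total_ram_gb() -> float:
    """Total physical RAM in GB (systems exposing sysconf, e.g. Linux/macOS)."""
    page_size = os.sysconf("SC_PAGE_SIZE")   # bytes per memory page
    num_pages = os.sysconf("SC_PHYS_PAGES")  # total physical pages
    return page_size * num_pages / 1024**3

ram = total_ram_gb()
print(f"Detected {ram:.1f} GB of RAM")
if ram < 96:
    print("Warning: switch-base-256 may exhaust memory on this machine.")
```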
import time
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
print("Loading model, please wait...")
# Start timer for the entire script
start_time = time.time()
# Timer for model loading
loading_start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-256")
model = SwitchTransformersForConditionalGeneration.from_pretrained("/whole/path/to/google_switch-base-256/", local_files_only=True)
loading_end_time = time.time()
loading_time = loading_end_time - loading_start_time
print(f"Model loading time: {loading_time:.2f} seconds")
# Simple conversation memory
conversation_history = []
while True:
    # Get user input
    user_input = input("User: ")
    # Add user input to conversation history
    conversation_history.append(f"User: {user_input}")
    # Combine conversation history into a single string
    conversation_text = " ".join(conversation_history)
    # Timer for generating the response
    generation_start_time = time.time()
    # Tokenize and generate response
    input_ids = tokenizer(conversation_text, return_tensors="pt").input_ids
    decoder_start_token_id = tokenizer.convert_tokens_to_ids("<pad>")
    max_new_tokens = 2048  # Adjust the value as needed
    outputs = model.generate(input_ids, decoder_start_token_id=decoder_start_token_id, max_new_tokens=max_new_tokens)
    generation_end_time = time.time()
    generation_time = generation_end_time - generation_start_time
    # Print the generated output (skip <pad>/</s> so they don't pollute the history)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"SwitchBot: {generated_text}")
    print(f"Response generation time: {generation_time:.2f} seconds")
    # Token counter
    response_tokens = len(outputs[0])
    tokens_per_second = response_tokens / generation_time
    print(f"Number of response tokens: {response_tokens}")
    print(f"Tokens per second: {tokens_per_second:.2f}")
    # Add SwitchBot's response to conversation history
    conversation_history.append(f"SwitchBot: {generated_text}")
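If the growing history starts to confuse the model, one option is to cap it at the last few exchanges before joining it into the prompt. This helper is a hypothetical addition of mine, not part of the script above; the keep_turns value is arbitrary:

```python
def trim_history(history, keep_turns=3):
    """Keep only the last `keep_turns` user/bot exchanges (2 lines each)."""
    max_lines = keep_turns * 2
    return history[-max_lines:]

# Example: with keep_turns=1, only the most recent exchange survives.
history = ["User: hi", "SwitchBot: hello", "User: how are you?", "SwitchBot: fine"]
print(" ".join(trim_history(history, keep_turns=1)))
# → User: how are you? SwitchBot: fine
```

In the loop above you would call it right before the `" ".join(...)` step.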
Note: I downloaded all the model files to a local directory and ran the script from that same directory. You will need to edit the script to point it at the full path of that directory.
Thanks for sharing @JuLuComputing !
Thanks @JuLuComputing !
You bet, peeps! I'm glad to help the community. Let me know if there are any questions about the script.