# Chat-generation demo: load an instruction-tuned Gemma 2 checkpoint, switch
# generation to a static KV cache, compile the forward pass, and print the
# model's reply to a single user message.

import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Explicitly disable parallelism inside the `tokenizers` library; set before
# any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Single source of truth for values that were previously repeated inline, so
# the tokenizer/model checkpoint and the model/input device cannot drift apart.
MODEL_ID = "gg-hf/gemma-2-2b-it"
DEVICE = "cuda:1"
MAX_NEW_TOKENS = 256

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)

# Ask `generate` to use the "static" cache implementation. This is set before
# compiling so the compiled forward is traced against that cache.
# NOTE(review): presumably chosen so tensor shapes stay fixed under
# `fullgraph=True` -- confirm against the Transformers static-cache docs.
model.generation_config.cache_implementation = "static"

# Replace the bound forward with a compiled version. `fullgraph=True` makes
# compilation fail loudly on graph breaks instead of silently falling back.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

messages = [
    {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
]

# Render the conversation with the tokenizer's chat template, append the
# generation prompt, and move the resulting tensors to the model's device.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(DEVICE)

outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)

# Decode every returned sequence to text, dropping special tokens.
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))