from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it").to("cuda:1")

# Use a static KV cache and compile the forward pass for faster generation
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

messages = [
    {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
]

inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to("cuda:1")

outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
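
# Optional: a minimal sketch for timing the effect of compilation. The first
# generate() calls are slow while torch.compile traces the model and captures
# CUDA graphs in "reduce-overhead" mode; repeated calls with the same shapes
# then reuse the compiled kernels and run much faster.
import time

for i in range(3):
    torch.cuda.synchronize("cuda:1")
    start = time.perf_counter()
    _ = model.generate(**inputs, max_new_tokens=256)
    torch.cuda.synchronize("cuda:1")
    print(f"run {i}: {time.perf_counter() - start:.2f}s")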