reach-vb (HF staff) committed
Commit a642e69
1 parent: 98b0855

Create test-compile.py (#6)

- Create test-compile.py (8101093a7bbf9d54eea9c8791427c5463a1d2d2b)

Files changed (1)
  1. test-compile.py +25 -0
test-compile.py ADDED
@@ -0,0 +1,25 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)
+
+ tokenizer = AutoTokenizer.from_pretrained("gg-hf/gemma-2-2b-it")
+ model = AutoModelForCausalLM.from_pretrained("gg-hf/gemma-2-2b-it").to("cuda:1")
+
+ model.generation_config.cache_implementation = "static"  # static KV cache keeps tensor shapes fixed across decoding steps
+
+ model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)  # compile the forward pass
+ messages = [
+     {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
+ ]
+
+ inputs = tokenizer.apply_chat_template(
+     messages,
+     tokenize=True,
+     add_generation_prompt=True,
+     return_tensors="pt",
+     return_dict=True,
+ ).to("cuda:1")
+
+ outputs = model.generate(**inputs, max_new_tokens=256)
+ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))