reach-vb (HF staff) committed
Commit 98b0855
1 Parent(s): 900797c

Create test-tok-sec.py (#5)


- Create test-tok-sec.py (973c6f612c932a9cec506428de5af10f092709f1)

Files changed (1)
  1. test-tok-sec.py +52 -0
test-tok-sec.py ADDED
@@ -0,0 +1,52 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ from tqdm import tqdm
+
+ # Other configuration options
+ DEVICE = "cuda:1"  # assumes a second GPU; use "cuda:0" on single-GPU machines
+ NUM_RUNS = 10
+ MAX_NEW_TOKENS = 1000
+ TEXT_INPUT = "def sieve_of_eratosthenes():"
+
+ # Load the model and prepare generate args
+ repo_id = "gg-hf/gemma-2-2b-it"
+ model = AutoModelForCausalLM.from_pretrained(repo_id).to(DEVICE)
+ # model = AutoModelForCausalLM.from_pretrained(repo_id, device_map="auto", torch_dtype=torch.bfloat16)
+
+ assistant_model = None  # hook for assisted generation; unused in this run
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
+ model_inputs = tokenizer(TEXT_INPUT, return_tensors="pt").to(DEVICE)
+
+ generate_kwargs = {
+     "max_new_tokens": MAX_NEW_TOKENS,
+     "do_sample": True,
+     "temperature": 0.2,
+     "eos_token_id": -1,  # forces the generation of `max_new_tokens`
+ }
+
+ # Warmup
+ print("Warming up...")
+ for _ in range(2):
+     gen_out = model.generate(**model_inputs, **generate_kwargs)
+ print("Done!")
+
+
+ # Measure OR Stream
+ def measure_generate(model, model_inputs, generate_kwargs):
+     start_event = torch.cuda.Event(enable_timing=True)
+     end_event = torch.cuda.Event(enable_timing=True)
+     torch.cuda.reset_peak_memory_stats(DEVICE)
+     torch.cuda.empty_cache()
+     torch.cuda.synchronize()
+
+     start_event.record()
+     for _ in tqdm(range(NUM_RUNS)):
+         gen_out = model.generate(**model_inputs, **generate_kwargs)
+     end_event.record()
+
+     torch.cuda.synchronize()
+     max_memory = torch.cuda.max_memory_allocated(DEVICE)
+     print("Max memory (MB): ", max_memory * 1e-6)
+     print("Throughput (tokens/sec): ", (NUM_RUNS * MAX_NEW_TOKENS) / (start_event.elapsed_time(end_event) * 1.0e-3))  # elapsed_time() is in ms
+
+ measure_generate(model, model_inputs, generate_kwargs)
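
The `# Measure OR Stream` comment suggests the script doubles as a streaming demo, printing tokens as they are generated rather than timing them. That branch is not part of this commit; a minimal sketch of what it could look like, using the stock `TextStreamer` from `transformers`:

from transformers import TextStreamer

def stream_generate(model, tokenizer, model_inputs, generate_kwargs):
    # Hypothetical streaming counterpart to measure_generate: decodes and
    # prints tokens to stdout as they arrive, skipping the prompt text
    streamer = TextStreamer(tokenizer, skip_prompt=True)
    model.generate(**model_inputs, streamer=streamer, **generate_kwargs)

Likewise, the unused `assistant_model` variable looks like a hook for assisted generation: loading a smaller draft checkpoint and passing it via `model.generate(..., assistant_model=assistant_model)` would let the same `measure_generate` harness time speculative decoding.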