agarkovv committed on
Commit
97fdf26
·
verified ·
1 Parent(s): 74aabba

Create run_inference.py

Browse files
Files changed (1) hide show
  1. run_inference.py +31 -0
run_inference.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Run a single prompt through the LoRA-tuned Ministral-8B trading model.

Loads the PEFT adapter checkpoint together with the base model's tokenizer,
generates a completion for ``PROMPT`` and prints only the newly generated
text.
"""

import re

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


PROMPT = "YOUR PROMPT HERE"
MAX_LENGTH = 32768  # Do not change
DEVICE = "cuda"


model_id = "agarkovv/Ministral-8B-Instruct-2410-LoRA-trading"
base_model_id = "mistralai/Ministral-8B-Instruct-2410"

# The adapter is loaded from `model_id`; the tokenizer is taken from the
# base model repository.
model = AutoPeftModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

model = model.to(DEVICE)
model.eval()

# Single un-padded sequence, truncated to MAX_LENGTH tokens.
inputs = tokenizer(
    PROMPT, return_tensors="pt", padding=False, max_length=MAX_LENGTH, truncation=True
)
inputs = {key: value.to(model.device) for key, value in inputs.items()}

# Remember how many tokens belong to the prompt so they can be cut off the
# generated sequence below.
prompt_length = inputs["input_ids"].shape[1]

res = model.generate(
    **inputs,
    use_cache=True,
    max_new_tokens=MAX_LENGTH,
)

# For a causal LM, `generate` returns the prompt tokens followed by the new
# tokens. Slicing by token count removes the echoed prompt reliably; the
# previous approach relied solely on finding a literal "[/INST]" in the
# decoded text, which is absent when PROMPT carries no such marker.
# NOTE(review): the marker is presumably also dropped by
# `skip_special_tokens=True` if the tokenizer registers it as a special
# token -- confirm against the tokenizer config.
generated_ids = res[0][prompt_length:]
output = tokenizer.decode(generated_ids, skip_special_tokens=True)

# Kept from the original script: if a literal "[/INST]" marker still shows up
# in the decoded text, drop everything up to and including it.
answer = re.sub(r".*\[/INST\]\s*", "", output, flags=re.DOTALL).lstrip()
print(answer)