Vasily Alexeev committed
Commit e1b1ab8 · 1 Parent(s): f548e70

add examples

Files changed (1):
  1. README.md +171 -0
README.md CHANGED
@@ -27,3 +27,174 @@ Quantized with [OmniQuant](https://github.com/OpenGVLab/OmniQuant).
 
 ## Examples
 
+### Imports and Model Loading
+
+<details>
+<summary>Expand</summary>
+
+```python
+import gc
+
+import auto_gptq.nn_modules.qlinear.qlinear_cuda as qlinear_cuda
+import auto_gptq.nn_modules.qlinear.qlinear_triton as qlinear_triton
+import torch
+
+from accelerate import (
+    init_empty_weights,
+    infer_auto_device_map,
+    load_checkpoint_in_model,
+)
+from tqdm import tqdm
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    pipeline,
+)
+
+
+def get_named_linears(model):
+    return {
+        name: module for name, module in model.named_modules()
+        if isinstance(module, torch.nn.Linear)
+    }
+
+
+def set_module(model, name, module):
+    parent = model
+    levels = name.split('.')
+
+    for i in range(len(levels) - 1):
+        cur_name = levels[i]
+
+        if cur_name.isdigit():
+            parent = parent[int(cur_name)]
+        else:
+            parent = getattr(parent, cur_name)
+
+    setattr(parent, levels[-1], module)
+
+
+def load_model(model_path):
+    # Based on: https://github.com/OpenGVLab/OmniQuant/blob/main/runing_quantized_mixtral_7bx8.ipynb
+
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+
+    if not hasattr(config, 'quantization_config'):
+        raise AttributeError(
+            f'No quantization info found in model config "{model_path}"'
+            f' (`quantization_config` section is missing).'
+        )
+
+    wbits = config.quantization_config['bits']
+    group_size = config.quantization_config['group_size']
+
+    # We are going to init an ordinary model and then manually replace all Linears with QuantLinears
+    del config.quantization_config
+
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
+
+    layers = model.model.layers
+
+    for i in tqdm(range(len(layers))):
+        layer = layers[i]
+        named_linears = get_named_linears(layer)
+
+        for name, module in named_linears.items():
+            params = (
+                wbits, group_size,
+                module.in_features, module.out_features,
+                module.bias is not None
+            )
+
+            if wbits in [2, 4]:
+                q_linear = qlinear_triton.QuantLinear(*params)
+            elif wbits == 3:
+                q_linear = qlinear_cuda.QuantLinear(*params)
+            else:
+                raise NotImplementedError("Only 2, 3 and 4 bits are supported.")
+
+            q_linear.to(next(layer.parameters()).device)
+            set_module(layer, name, q_linear)
+
+    torch.cuda.empty_cache()
+    gc.collect()
+
+    model.tie_weights()
+    device_map = infer_auto_device_map(model)
+
+    print("Loading pre-computed quantized weights...")
+
+    load_checkpoint_in_model(
+        model, checkpoint=model_path,
+        device_map=device_map, offload_state_dict=True,
+    )
+
+    print("Model loaded successfully!")
+
+    return model
+```
+</details>
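+
+Before loading, one can inspect the quantization settings stored with the model. A minimal sketch, not part of the original example; `load_model` above expects the `bits` and `group_size` keys:
+
+```python
+from transformers import AutoConfig
+
+config = AutoConfig.from_pretrained(
+    "compressa-ai/Saiga-Llama-3-8B-OmniQuant", trust_remote_code=True
+)
+print(config.quantization_config)  # expected to contain 'bits' and 'group_size'
+```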
+
+
+### Inference
+
+```python
+model_path = "compressa-ai/Saiga-Llama-3-8B-OmniQuant"
+
+model = load_model(model_path)
+model.cuda()
+tokenizer = AutoTokenizer.from_pretrained(
+    model_path, use_fast=False, trust_remote_code=True
+)
+
+# "You are a friendly chatbot that always answers like a pirate."
+system_message = "Ты — дружелюбный чат-бот, который всегда отвечает как пират."
+# "Where are we heading, captain?"
+user_message = "Куда мы направляемся, капитан?"
+messages = [
+    {"role": "system", "content": system_message},
+    {"role": "user", "content": user_message},
+]
+prompt = tokenizer.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)
+
+inputs = tokenizer(prompt, return_tensors="pt")
+inputs = {k: v.cuda() for k, v in inputs.items()}
+
+outputs = model.generate(
+    **inputs, max_new_tokens=512,
+    do_sample=True, temperature=0.7, top_p=0.95,
+)
+
+response = tokenizer.decode(outputs[0])  # , skip_special_tokens=True)
+continuation = response.removeprefix(prompt).removesuffix(tokenizer.eos_token)
+
+print(f'Prompt:\n{prompt}')
+print(f'Continuation:\n{continuation}\n')
+```
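+
+The same chat template can be extended for a follow-up turn. A minimal sketch reusing `model`, `tokenizer`, `messages` and `continuation` from the example above; the follow-up question itself is illustrative:
+
+```python
+followup = messages + [
+    {"role": "assistant", "content": continuation},
+    # "And when will we arrive, captain?"
+    {"role": "user", "content": "А когда мы прибудем, капитан?"},
+]
+prompt = tokenizer.apply_chat_template(
+    followup, tokenize=False, add_generation_prompt=True
+)
+inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
+
+outputs = model.generate(
+    **inputs, max_new_tokens=512,
+    do_sample=True, temperature=0.7, top_p=0.95,
+)
+print(tokenizer.decode(outputs[0]).removeprefix(prompt).removesuffix(tokenizer.eos_token))
+```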
+
+
+### Inference Using Pipeline
+
+```python
+pipe = pipeline(
+    "text-generation",
+    model=model, tokenizer=tokenizer,
+    max_new_tokens=512, do_sample=True,
+    temperature=0.7, top_p=0.95,
+    device=0,
+)
+
+prompt = pipe.tokenizer.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)
+
+outputs = pipe(prompt)
+
+response = outputs[0]["generated_text"]
+continuation = response.removeprefix(prompt)
+
+print(f'Prompt:\n{prompt}')
+print(f'Continuation:\n{continuation}\n')
+```
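+
+To get only the newly generated text without the `removeprefix` step, the standard `return_full_text` argument of the `text-generation` pipeline can be used (a minimal variation on the example above, not specific to this model):
+
+```python
+outputs = pipe(prompt, return_full_text=False)
+continuation = outputs[0]["generated_text"]
+
+print(f'Continuation:\n{continuation}\n')
+```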