---
license: apache-2.0
train: false
inference: false
pipeline_tag: text-generation
---
This is an <a href="https://github.com/mobiusml/hqq/">HQQ</a> 4-bit quantized <a href="https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1">Mixtral-8x7B-Instruct-v0.1</a> model (all linear layers quantized, group-size=64).
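
The weights were produced with 4-bit, group-size 64 HQQ settings applied to every linear layer. Below is a minimal sketch of how an equivalent model could be re-quantized from the original fp16 checkpoint (not needed to use this repo; assumes a recent `transformers` release with the built-in HQQ integration):
```Python
import torch
from transformers import AutoModelForCausalLM, HqqConfig

#Same settings as this repo: 4-bit weights, group-size 64, applied to every linear layer
quant_config = HqqConfig(nbits=4, group_size=64)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)
```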


## Usage
First, install the dependencies:
```
pip install git+https://github.com/mobiusml/hqq.git
pip install git+https://github.com/mobiusml/gemlite.git  #to use the gemlite backend
pip install bitblas                                      #to use the bitblas backend
```
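
Optionally, verify that the packages import and that a CUDA device is visible, since the quantized backends used below run CUDA kernels:
```Python
import torch
import hqq

#The gemlite / torchao_int4 / bitblas backends all require a CUDA-capable GPU
assert torch.cuda.is_available(), "a CUDA GPU is required for the quantized kernels"
print(torch.cuda.get_device_name(0))
```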

Then you can use the sample code below:
## Transformers 🤗
```Python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from hqq.utils.patching import *
from hqq.core.quantize import *
from hqq.utils.generation_hf import patch_model_for_compiled_runtime

#Settings
###################################################
backend       = "gemlite" #"torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit) or "gemlite" (8-bit, 4-bit, 2-bit, 1-bit)
compute_dtype = torch.bfloat16 if backend=="torchao_int4" else torch.float16
device        = 'cuda:0'
cache_dir     = '.'
model_id      = "mobiuslabsgmbh/Mixtral-8x7B-Instruct-v0.1_4bitgs64_hqq_hf"

model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype, cache_dir=cache_dir, device_map=device, attn_implementation="sdpa")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)

#Use optimized inference kernels
########################################################################
prepare_for_inference(model, backend=backend, verbose=True) 

#Load gemlite cache for faster warm-up
if(backend == 'gemlite'):
    import gemlite
    gemlite.core.GemLiteLinear.load_config('gemlite_config.json')

#Generate
########################################################################
from hqq.utils.generation_hf import HFGenerator
#Mixtral doesn't support cuda graphs with HF unfortunately...
#gen = HFGenerator(model, tokenizer, max_new_tokens=1000, do_sample=True, compile=None)

gen = HFGenerator(model, tokenizer, max_new_tokens=1000, do_sample=True, compile="partial",
                  compile_options={"mode": "max-autotune-no-cudagraphs"}) #.enable_cuda_graph()

gen.generate("Write an essay about large language models", print_tokens=True)

########################################################################
# #Inference with model.generate()
# from hqq.utils.generation_hf import patch_model_for_compiled_runtime

# patch_model_for_compiled_runtime(model, tokenizer, pre_compile=False) 

# prompt  = "Write an essay about large language models."
# inputs  = tokenizer.apply_chat_template([{"role":"user", "content":prompt}], tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True)
# outputs = model.generate(**inputs.to(model.device), max_new_tokens=1000, cache_implementation="static", pad_token_id=tokenizer.pad_token_id) 
# #print(tokenizer.decode(outputs[0]))

########################################################################
#Save gemlite cache
if(backend == 'gemlite'):
    gemlite.core.GemLiteLinear.cache_config('gemlite_config.json') #same file that is loaded at warm-up
```

## vLLM
Run with <a href="https://github.com/vllm-project/vllm/">vLLM</a>. The snippet below first replaces `mixtral_quant.MixtralMLP` with a `RowParallelLinear`-based version, then loads the quantized model and runs generation:
```Python
##################################################################
import torch
import torch.nn as nn
from typing import Optional
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
class MixtralMLPRowParallel(nn.Module):

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        self.w1 = RowParallelLinear(self.hidden_dim,
                                   self.ffn_dim,
                                   bias=False,
                                   quant_config=quant_config)
        self.w2 = RowParallelLinear(self.ffn_dim,
                                   self.hidden_dim,
                                   bias=False,
                                   quant_config=quant_config)
        self.w3 = RowParallelLinear(self.hidden_dim,
                                   self.ffn_dim,
                                   bias=False,
                                   quant_config=quant_config)

        # TODO: Use vllm's SiluAndMul
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states

import vllm.model_executor.models.mixtral_quant as mixtral_quant
mixtral_quant.MixtralMLP = MixtralMLPRowParallel
##################################################################

from vllm import LLM
from vllm.sampling_params import SamplingParams
model_id = "mobiuslabsgmbh/Mixtral-8x7B-Instruct-v0.1_4bitgs64_hqq_hf"

llm = LLM(model=model_id, gpu_memory_utilization=0.80)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)
outputs = llm.generate(["What is the capital of Germany?"], sampling_params)
print(outputs[0].outputs[0].text)
```
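
For instruction-style prompts it usually helps to apply the model's chat template before calling `llm.generate`. A minimal sketch, reusing `llm`, `sampling_params` and `model_id` from the snippet above:
```Python
from transformers import AutoTokenizer

tokenizer   = AutoTokenizer.from_pretrained(model_id)
chat_prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is the capital of Germany?"}],
    tokenize=False,
    add_generation_prompt=True,
)

outputs = llm.generate([chat_prompt], sampling_params)
print(outputs[0].outputs[0].text)
```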