mobicham committed
Commit 62b1949 · verified · 1 Parent(s): 2168a7d

Update README.md

Files changed (1)
  1. README.md +59 -0
README.md CHANGED
@@ -72,3 +72,62 @@ if(backend == 'gemlite'):
  gemlite.core.GemLiteLinear.cache_config('/tmp/gemlite_config.json')
  ```
 
+Run with <a href="https://github.com/vllm-project/vllm/">vllm</a>:
+```Python
+##################################################################
+import torch
+import torch.nn as nn
+from typing import Optional
+
+from vllm.model_executor.layers.linear import RowParallelLinear
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+
+# Row-parallel version of the Mixtral expert MLP: each of the three
+# projections is a RowParallelLinear, so the weights can be sharded
+# across tensor-parallel ranks.
+class MixtralMLPRowParallel(nn.Module):
+
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.num_experts = num_experts
+        self.ffn_dim = intermediate_size
+        self.hidden_dim = hidden_size
+
+        self.w1 = RowParallelLinear(self.hidden_dim,
+                                    self.ffn_dim,
+                                    bias=False,
+                                    quant_config=quant_config)
+        self.w2 = RowParallelLinear(self.ffn_dim,
+                                    self.hidden_dim,
+                                    bias=False,
+                                    quant_config=quant_config)
+        self.w3 = RowParallelLinear(self.hidden_dim,
+                                    self.ffn_dim,
+                                    bias=False,
+                                    quant_config=quant_config)
+
+        # TODO: Use vllm's SiluAndMul
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # SwiGLU: silu(w1(x)) * w3(x), projected back down by w2
+        w1_out, _ = self.w1(hidden_states)
+        w1_out = self.act_fn(w1_out)
+        w3_out, _ = self.w3(hidden_states)
+        current_hidden_states = w1_out * w3_out
+        current_hidden_states, _ = self.w2(current_hidden_states)
+        return current_hidden_states
+
+# Monkey-patch the quantized Mixtral implementation to use the class above
+import vllm.model_executor.models.mixtral_quant as mixtral_quant
+mixtral_quant.MixtralMLP = MixtralMLPRowParallel
+##################################################################
+
+from vllm import LLM
+from vllm.sampling_params import SamplingParams
+
+model_id = "mobiuslabsgmbh/Mixtral-8x7B-Instruct-v0.1_4bitgs64_hqq_hf"
+
+llm = LLM(model=model_id, gpu_memory_utilization=0.80)
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)
+outputs = llm.generate(["What is the capital of Germany?"], sampling_params)
+print(outputs[0].outputs[0].text)
+```
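+
+The `TODO` above refers to vLLM's fused `SiluAndMul` activation. A minimal sketch of that swap, assuming `vllm.model_executor.layers.activation.SiluAndMul`, which splits its input in half along the last dimension and computes `silu(gate) * up` (check the import path against your vLLM version):
+```Python
+from vllm.model_executor.layers.activation import SiluAndMul
+
+# In MixtralMLPRowParallel.__init__, set instead: self.act_fn = SiluAndMul()
+def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    w1_out, _ = self.w1(hidden_states)  # gate projection
+    w3_out, _ = self.w3(hidden_states)  # up projection
+    # SiluAndMul expects [gate, up] concatenated on the last dimension
+    current_hidden_states = self.act_fn(torch.cat([w1_out, w3_out], dim=-1))
+    current_hidden_states, _ = self.w2(current_hidden_states)
+    return current_hidden_states
+```
+In vLLM's own model code the gate and up projections are typically merged into a single matmul, so the concatenation comes for free there; the sketch keeps the two separate projections from the class above.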
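+
+Because the patched MLP uses `RowParallelLinear`, the same script should also run with tensor parallelism, sharding the expert weights across GPUs. A minimal sketch, assuming two visible GPUs (only the `LLM(...)` call changes):
+```Python
+llm = LLM(model=model_id, gpu_memory_utilization=0.80, tensor_parallel_size=2)
+```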