|
--- |
|
license: apache-2.0 |
|
inference: false |
|
tags: |
|
- auto-gptq |
|
pipeline_tag: text-generation |
|
--- |
|
|
|
|
|
# RedPajama GPTQ: RedPajama-INCITE-Chat-3B-v1
|
|
|
<a href="https://colab.research.google.com/gist/pszemraj/86d2e8485df182302646ed2c5a637059/inference-with-redpajama-incite-chat-3b-v1-gptq-4bit-128g.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
|
|
|
A 4-bit GPTQ quantization of [RedPajama-INCITE-Chat-3B-v1](https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1), created with `auto-gptq`. The quantized model file is only 2 GB.
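
For reference, a checkpoint like this one can be produced with `auto-gptq` roughly as sketched below. This is a minimal illustration, not the exact script used for this repo: the calibration example and the `desc_act` setting are assumptions, and only the 4-bit weights with group size 128 are taken from the repo name.

```python
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer

base_model = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"
tokenizer = AutoTokenizer.from_pretrained(base_model)

# 4-bit weights with a group size of 128, matching the "4bit-128g" in the repo name
quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
model = AutoGPTQForCausalLM.from_pretrained(base_model, quantize_config)

# tiny placeholder calibration set; a real run should use many representative samples
examples = [tokenizer("<human>: What is GPTQ?\n<bot>: GPTQ is a post-training quantization method.")]
model.quantize(examples)

model.save_quantized("RedPajama-INCITE-Chat-3B-v1-GPTQ-4bit-128g", use_safetensors=True)
```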
|
|
|
|
|
## Usage |
|
|
|
|
|
> Note that you cannot yet load this model directly from the Hub with `auto_gptq` - if needed, you can use [this function](https://gist.github.com/pszemraj/8368cba3400bda6879e521a55d2346d0) to download the files by repo name.
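
As an alternative to the linked gist, `huggingface_hub.snapshot_download` can pull the files into a local folder first; the repo id below is an assumed example, so substitute the repo you actually want to load:

```python
from huggingface_hub import snapshot_download

# assumed repo id for illustration; replace with the GPTQ repo you are using
repo_id = "pszemraj/RedPajama-INCITE-Chat-3B-v1-GPTQ-4bit-128g"

# downloads the full snapshot into the local HF cache and returns the directory path
model_dir = snapshot_download(repo_id)
print(model_dir)
```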
|
|
|
|
|
First, install ninja and `auto-gptq` with the triton extra:
|
|
|
```bash
pip install ninja "auto-gptq[triton]"
```
|
|
|
Load the tokenizer and quantized model:
|
|
|
```python
import torch
from pathlib import Path

from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer

# path to the locally downloaded quantized model files
model_repo = Path.cwd() / "RedPajama-INCITE-Chat-3B-v1-GPTQ-4bit-128g"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoGPTQForCausalLM.from_quantized(
    model_repo,
    device=device,
    use_safetensors=True,  # the weights in this repo are stored as safetensors
    use_triton=device != "cpu",  # triton kernels require Linux + CUDA; set to False otherwise
).to(device)
```
|
|
|
Inference: |
|
|
|
```python
import re

prompt = "How can I further strive to increase shareholder value even further?"
prompt = f"<human>: {prompt}\n<bot>:"  # wrap in the <human>/<bot> chat format the model was tuned on
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    temperature=0.7,
    do_sample=True,
    max_new_tokens=192,
    length_penalty=0.9,
    pad_token_id=model.config.eos_token_id,
)
result = tokenizer.batch_decode(
    outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
)

# extract the bot reply: everything after "<bot>:" up to the next "<human>" turn (or end of text)
bot_responses = re.findall(r'<bot>:(.*?)(<human>|$)', result[0], re.DOTALL)
bot_responses = [response[0].strip() for response in bot_responses]

print(bot_responses[0])
```
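
For convenience, the prompt formatting, generation, and response extraction above can be wrapped in a small helper. `ask_bot` below is a hypothetical wrapper built from the same snippets, not part of the model's API; it assumes `model`, `tokenizer`, and `re` are already available from the previous steps:

```python
def ask_bot(prompt: str, max_new_tokens: int = 192) -> str:
    """Format a single <human>/<bot> turn, generate, and return the bot's reply."""
    text = f"<human>: {prompt}\n<bot>:"
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        penalty_alpha=0.6,
        top_k=4,
        temperature=0.7,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        pad_token_id=model.config.eos_token_id,
    )
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    matches = re.findall(r'<bot>:(.*?)(<human>|$)', decoded, re.DOTALL)
    return matches[0][0].strip() if matches else decoded


print(ask_bot("Name three uses for a quantized language model."))
```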