mobicham committed (verified)
Commit 1e4dbbc · Parent(s): 66f3e32

Update README.md

Files changed (1): README.md (+9 -10)
README.md CHANGED

@@ -60,25 +60,24 @@ from hqq.utils.patching import *
 from hqq.core.quantize import *
 from hqq.utils.generation_hf import HFGenerator
 
+#Settings
+###################################################
+backend = "torchao_int4" #"torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit) or "gemlite" (8-bit, 4-bit, 2-bit, 1-bit)
+compute_dtype = torch.bfloat16 if backend=="torchao_int4" else torch.float16
+device = 'cuda:0'
+cache_dir = '.'
+
 #Load the model
 ###################################################
 #model_id = 'mobiuslabsgmbh/Llama-3.1-8b-instruct_4bitgs64_hqq' #no calib version
 model_id = 'mobiuslabsgmbh/Llama-3.1-8b-instruct_4bitgs64_hqq_calib' #calibrated version
 
-compute_dtype = torch.bfloat16 #bfloat16 for torchao_int4, float16 for bitblas
-cache_dir = '.'
-model = AutoHQQHFModel.from_quantized(model_id, cache_dir=cache_dir, compute_dtype=compute_dtype)
+model = AutoHQQHFModel.from_quantized(model_id, cache_dir=cache_dir, compute_dtype=compute_dtype, device=device).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
 
-quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
-patch_linearlayers(model, patch_add_quant_config, quant_config)
-
 #Use optimized inference kernels
 ###################################################
-HQQLinear.set_backend(HQQBackend.PYTORCH)
-#prepare_for_inference(model) #default backend
-prepare_for_inference(model, backend="torchao_int4")
-#prepare_for_inference(model, backend="bitblas") #takes a while to init...
+prepare_for_inference(model, backend=backend)
 
 #Generate
 ###################################################
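For reference, below is a minimal end-to-end sketch of how the snippet reads after this change. The import lines and the body of the #Generate section are outside this hunk, so the imports and the HFGenerator calls shown here are assumptions based on the usual upstream hqq examples, not part of this commit.

```python
#End-to-end sketch of the README snippet as it reads after this commit.
#ASSUMPTIONS: the import lines and the #Generate body are not part of this
#hunk; they follow the usual hqq examples and may differ from the actual README.
import torch
from transformers import AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.utils.patching import prepare_for_inference
from hqq.utils.generation_hf import HFGenerator

#Settings
###################################################
backend       = "torchao_int4" #"torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit) or "gemlite" (8-bit, 4-bit, 2-bit, 1-bit)
compute_dtype = torch.bfloat16 if backend == "torchao_int4" else torch.float16
device        = 'cuda:0'
cache_dir     = '.'

#Load the model
###################################################
model_id = 'mobiuslabsgmbh/Llama-3.1-8b-instruct_4bitgs64_hqq_calib' #calibrated version

model     = AutoHQQHFModel.from_quantized(model_id, cache_dir=cache_dir, compute_dtype=compute_dtype, device=device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)

#Use optimized inference kernels
###################################################
prepare_for_inference(model, backend=backend)

#Generate (assumed: HFGenerator usage as in the upstream hqq README)
###################################################
gen = HFGenerator(model, tokenizer, max_new_tokens=1000, do_sample=True, compile="partial").warmup() #warm-up takes a while
gen.generate("Write an essay about large language models", print_tokens=True)
```

The conditional compute_dtype preserves the intent of the removed comment (bfloat16 for torchao_int4, float16 for bitblas) while letting a single backend variable drive both the dtype choice and the prepare_for_inference call.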