Hjgugugjhuhjggg committed on
Commit eec8624 · verified · 1 Parent(s): 0edd18a

Update app.py

Files changed (1): app.py (+9 -2)
app.py CHANGED
@@ -14,6 +14,7 @@ from functools import cached_property
 import base64
 from optimum.onnxruntime import ORTModelForCausalLM
 from optimum.bettertransformer import BetterTransformer
+import bitsandbytes as bnb
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -68,7 +69,10 @@ class S3ModelLoader:
         if use_onnx:
             model = ORTModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False).to(self.device)
         elif quantize:
-            model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False, torch_dtype=torch.int8 if quantize else torch.float16).to(self.device)
+            model = AutoModelForCausalLM.from_pretrained(
+                s3_uri, config=config, local_files_only=False,
+                load_in_8bit=True
+            ).to(self.device)
         else:
             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False).to(self.device)
         if use_bettertransformer:
@@ -84,7 +88,10 @@ class S3ModelLoader:
         if use_onnx:
             model = ORTModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN).to(self.device)
         elif quantize:
-            model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN, torch_dtype=torch.int8 if quantize else torch.float16).to(self.device)
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name, config=config, token=HUGGINGFACE_HUB_TOKEN,
+                load_in_8bit=True
+            ).to(self.device)
         else:
             model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN).to(self.device)
         if use_bettertransformer:
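
For reference, below is a minimal sketch of the 8-bit loading path that this commit switches to, written outside of S3ModelLoader and under stated assumptions: it assumes a transformers version that provides BitsAndBytesConfig, a CUDA-capable GPU with bitsandbytes installed, and it uses "gpt2" as a placeholder model id rather than anything resolved by app.py. Passing quantization_config is the currently documented form of what the diff does with the bare load_in_8bit=True kwarg; note also that bitsandbytes dispatches the quantized weights itself, so the trailing .to(self.device) in the diff may be redundant (recent transformers releases reject .to() on 8-bit models).

# Hedged sketch: load a causal LM in 8-bit via BitsAndBytesConfig.
# Assumptions: recent transformers + bitsandbytes, CUDA GPU available,
# "gpt2" is a placeholder model id (not taken from app.py).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "gpt2"  # placeholder; app.py resolves the real model name or S3 URI elsewhere

quant_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,  # preferred over the bare load_in_8bit=True kwarg
    device_map="auto",                 # bitsandbytes places weights on the GPU; no .to(device) needed
)

# Quick smoke test of the quantized model.
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))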