Update app.py
app.py
CHANGED
@@ -14,7 +14,7 @@ from functools import cached_property
 import base64
 from optimum.onnxruntime import ORTModelForCausalLM
 from optimum.bettertransformer import BetterTransformer
-
+
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -41,7 +41,6 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = False
     stop_sequences: list[str] = []
-    quantize: bool = True
     use_onnx: bool = False
     use_bettertransformer: bool = True
     @field_validator("model_name")
@@ -62,17 +61,12 @@ class S3ModelLoader:
         self.model_cache = {}
     def _get_s3_uri(self, model_name):
         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
-    async def _load_model_and_tokenizer(self, model_name, quantize, use_onnx, use_bettertransformer):
+    async def _load_model_and_tokenizer(self, model_name, use_onnx, use_bettertransformer):
         s3_uri = self._get_s3_uri(model_name)
         try:
             config = AutoConfig.from_pretrained(s3_uri, local_files_only=False)
             if use_onnx:
                 model = ORTModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False).to(self.device)
-            elif quantize:
-                model = AutoModelForCausalLM.from_pretrained(
-                    s3_uri, config=config, local_files_only=False,
-                    load_in_8bit=True
-                ).to(self.device)
             else:
                 model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False).to(self.device)
             if use_bettertransformer:
@@ -87,11 +81,6 @@ class S3ModelLoader:
             tokenizer = AutoTokenizer.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN)
             if use_onnx:
                 model = ORTModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN).to(self.device)
-            elif quantize:
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_name, config=config, token=HUGGINGFACE_HUB_TOKEN,
-                    load_in_8bit=True
-                ).to(self.device)
             else:
                 model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN).to(self.device)
             if use_bettertransformer:
@@ -104,10 +93,10 @@ class S3ModelLoader:
     @cached_property
     def device(self):
         return torch.device("cpu")
-    async def get_model_and_tokenizer(self, model_name, quantize, use_onnx, use_bettertransformer):
-        key = f"{model_name}-{quantize}-{use_onnx}-{use_bettertransformer}"
+    async def get_model_and_tokenizer(self, model_name, use_onnx, use_bettertransformer):
+        key = f"{model_name}-{use_onnx}-{use_bettertransformer}"
         if key not in self.model_cache:
-            model, tokenizer = await self._load_model_and_tokenizer(model_name, quantize, use_onnx, use_bettertransformer)
+            model, tokenizer = await self._load_model_and_tokenizer(model_name, use_onnx, use_bettertransformer)
             self.model_cache[key] = {"model":model, "tokenizer":tokenizer}
         return self.model_cache[key]["model"], self.model_cache[key]["tokenizer"]
     async def get_pipeline(self, model_name, task_type):
@@ -133,10 +122,9 @@ async def generate(request: GenerateRequest):
     num_return_sequences = request.num_return_sequences
     do_sample = request.do_sample
     stop_sequences = request.stop_sequences
-    quantize = request.quantize
     use_onnx = request.use_onnx
     use_bettertransformer = request.use_bettertransformer
-    model, tokenizer = await model_loader.get_model_and_tokenizer(model_name, quantize, use_onnx, use_bettertransformer)
+    model, tokenizer = await model_loader.get_model_and_tokenizer(model_name, use_onnx, use_bettertransformer)
     if "text-to-text" == task_type:
         generation_config = GenerationConfig(temperature=temperature,max_new_tokens=max_new_tokens,top_p=top_p,top_k=top_k,repetition_penalty=repetition_penalty,do_sample=do_sample,num_return_sequences=num_return_sequences,eos_token_id = tokenizer.eos_token_id)
         if stream:
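For context, a minimal client sketch against the updated GenerateRequest schema after this change. It assumes the handler above is mounted as a POST /generate route and that the Space listens on http://localhost:7860 (neither the route decorator nor the port appears in this diff); the prompt field name input_text and the model name are placeholders:

import requests  # any HTTP client works; requests is used here for brevity

payload = {
    "model_name": "gpt2",            # placeholder; must satisfy the @field_validator check
    "input_text": "Hello, world!",   # hypothetical: the prompt field is outside this diff
    "task_type": "text-to-text",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
    "num_return_sequences": 1,
    "do_sample": True,
    "stop_sequences": [],
    "use_onnx": False,               # plain AutoModelForCausalLM path
    "use_bettertransformer": True,   # wrap the loaded model with BetterTransformer
    # "quantize" no longer exists on GenerateRequest after this commit
}

response = requests.post("http://localhost:7860/generate", json=payload)
response.raise_for_status()
print(response.json())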
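The removed branch passed load_in_8bit=True directly to from_pretrained. If 8-bit loading were ever reinstated, recent transformers versions route it through BitsAndBytesConfig instead; the following is a sketch under those assumptions, not this Space's code. It requires the bitsandbytes and accelerate packages and a CUDA device, whereas this Space pins device to torch.device("cpu"); note also that a quantized model must not be moved with .to() afterwards, so the removed code's trailing .to(self.device) would be dropped:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch only: 8-bit loading via the current quantization_config API.
# Assumes bitsandbytes is installed and a CUDA device is available.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",                          # placeholder model name
    quantization_config=bnb_config,
    device_map="auto",               # accelerate handles device placement; no .to()
)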