Hjgugugjhuhjggg committed
Update app.py
app.py
CHANGED
@@ -6,8 +6,7 @@ from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
     pipeline,
-
-    AutoModelForCausalLM,
+    AutoModel,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList
@@ -83,18 +82,10 @@ class S3ModelLoader:
                 s3_uri, local_files_only=True
             )
 
-
-            model = AutoModelForCausalLM.from_pretrained(
-                s3_uri, config=config, local_files_only=True, rope_scaling = {"type": "linear", "factor": 2.0}
-            )
-        elif 'qwen' in model_name:
-            model = AutoModelForCausalLM.from_pretrained(
-                s3_uri, config=config, local_files_only=True
-            )
-        else:
-            model = AutoModelForSeq2SeqLM.from_pretrained(
+            model = AutoModel.from_pretrained(
                 s3_uri, config=config, local_files_only=True
             )
+
 
             tokenizer = AutoTokenizer.from_pretrained(
                 s3_uri, config=config, local_files_only=True
@@ -115,19 +106,11 @@ class S3ModelLoader:
                 model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
             )
 
-
-            model = AutoModelForCausalLM.from_pretrained(
-                model_name, config=config, token=HUGGINGFACE_HUB_TOKEN, rope_scaling = {"type": "linear", "factor": 2.0}
-            )
-        elif 'qwen' in model_name:
-            model = AutoModelForCausalLM.from_pretrained(
-                model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
-            )
-        else:
-            model = AutoModelForSeq2SeqLM.from_pretrained(
+            model = AutoModel.from_pretrained(
                 model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
             )
 
+
         if tokenizer.eos_token_id is not None and \
                 tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = config.pad_token_id \
@@ -164,23 +147,26 @@ async def generate(request: GenerateRequest):
         load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
+
+        if "text-to-text" == task_type:
+            generation_config = GenerationConfig(
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                do_sample=do_sample,
+                num_return_sequences=num_return_sequences,
+            )
 
-        generation_config = GenerationConfig(
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            do_sample=do_sample,
-            num_return_sequences=num_return_sequences,
-        )
-
-        return StreamingResponse(
-            stream_text(model, tokenizer, input_text,
-                        generation_config, stop_sequences,
-                        device, chunk_delay),
-            media_type="text/plain"
-        )
+            return StreamingResponse(
+                stream_text(model, tokenizer, input_text,
+                            generation_config, stop_sequences,
+                            device, chunk_delay),
+                media_type="text/plain"
+            )
+        else:
+            return HTTPException(status_code=400, detail="Task type not text-to-text")
 
     except Exception as e:
         raise HTTPException(
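
Note on the model-loading change: AutoModel resolves to the bare transformer (hidden states only), while the removed AutoModelForCausalLM / AutoModelForSeq2SeqLM classes attach the language-model head that .generate() and the streaming path above rely on, so this swap may break generation at runtime. A minimal sketch of the difference, using "gpt2" purely as an illustrative checkpoint:

from transformers import AutoModel, AutoModelForCausalLM

# AutoModel.from_pretrained returns the base transformer, e.g. GPT2Model,
# which outputs hidden states and carries no language-model head.
base = AutoModel.from_pretrained("gpt2")
print(type(base).__name__)  # GPT2Model

# AutoModelForCausalLM returns the headed variant that supports .generate().
lm = AutoModelForCausalLM.from_pretrained("gpt2")
print(type(lm).__name__)    # GPT2LMHeadModel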
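
A second hedged note, on the new else branch: to my understanding, FastAPI only converts an HTTPException into an actual 400 response when the exception is raised; returning it, as this commit does, hands the exception object to the response serializer instead. A sketch of the raising form (the handler body is a hypothetical stand-in, not the app's real logic):

from fastapi import FastAPI, HTTPException

app = FastAPI()

@app.post("/generate")
async def generate(task_type: str = "text-to-text"):
    if task_type != "text-to-text":
        # raise, not return: a raised HTTPException becomes a proper
        # 400 response carrying the detail message below.
        raise HTTPException(status_code=400, detail="Task type not text-to-text")
    return {"status": "ok"}  # placeholder for the streaming path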