Hjgugugjhuhjggg committed: Update app.py
app.py CHANGED
@@ -74,28 +74,22 @@ class S3ModelLoader:
             logging.info(f"Trying to load {model_name} from S3...")
             config = AutoConfig.from_pretrained(s3_uri)
             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
-            tokenizer = AutoTokenizer.from_pretrained(s3_uri)
+            tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config)
 
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                if config.pad_token_id is not None:
-                    tokenizer.pad_token_id = config.pad_token_id
-                else:
-                    tokenizer.pad_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
+                tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
             logging.info(f"Loaded {model_name} from S3 successfully.")
             return model, tokenizer
         except EnvironmentError:
             logging.info(f"Model {model_name} not found in S3. Downloading...")
             try:
-
-                tokenizer = AutoTokenizer.from_pretrained(model_name,
-
+                config = AutoConfig.from_pretrained(model_name)
+                tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
+                model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN)
+
                 if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-
-                    if config.pad_token_id is not None:
-                        tokenizer.pad_token_id = config.pad_token_id
-                    else:
-                        tokenizer.pad_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
+                    tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
                 logging.info(f"Downloaded {model_name} successfully.")
                 logging.info(f"Saving {model_name} to S3...")
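The removed if/else and the new one-liner are not quite equivalent: Python's "or" falls through on any falsy value, so a config.pad_token_id of 0 (a valid token id) is skipped where the old "is not None" check kept it. A minimal standalone sketch of both behaviours, with plain ints standing in for the real config and tokenizer attributes:

    # Minimal sketch comparing the removed if/else with the new one-liner.
    # Plain ints stand in for config.pad_token_id / tokenizer.eos_token_id.

    def pad_id_old(config_pad_id, eos_id):
        # Removed block: explicit "is not None" check.
        if config_pad_id is not None:
            return config_pad_id
        return eos_id if eos_id is not None else 0

    def pad_id_new(config_pad_id, eos_id):
        # New one-liner: "or" also skips a falsy 0.
        return config_pad_id or eos_id

    assert pad_id_old(5, 2) == pad_id_new(5, 2) == 5
    assert pad_id_old(None, 2) == pad_id_new(None, 2) == 2
    assert pad_id_old(0, 2) == 0   # old code keeps a pad_token_id of 0
    assert pad_id_new(0, 2) == 2   # new code falls through to eos_token_id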
@@ -128,10 +122,7 @@ async def generate(request: Request, body: GenerateRequest):
         top_k=validated_body.top_k,
         repetition_penalty=validated_body.repetition_penalty,
         do_sample=validated_body.do_sample,
-        num_return_sequences=validated_body.num_return_sequences
-        stopping_criteria=StoppingCriteriaList(
-            [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
-        )
+        num_return_sequences=validated_body.num_return_sequences
     )
 
     async def stream_text():
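For context: in transformers, stopping criteria are an argument to model.generate() rather than a GenerationConfig field, which is why this hunk drops stopping_criteria from the config construction and the next hunk re-creates it at the generate() call. A hedged sketch of the resulting split, with placeholder values standing in for the validated_body fields:

    # Sketch of the split: sampling knobs in GenerationConfig, stopping
    # criteria passed to generate() itself. Values here are placeholders,
    # not the app's real request fields.
    from transformers import GenerationConfig

    generation_config = GenerationConfig(
        max_new_tokens=64,
        top_k=50,
        repetition_penalty=1.1,
        do_sample=True,
        num_return_sequences=1,
    )
    # later, inside the streaming loop:
    # output = model.generate(**encoded_input,
    #                         generation_config=generation_config,
    #                         stopping_criteria=stopping_criteria)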
@@ -149,7 +140,11 @@ async def generate(request: Request, body: GenerateRequest):
 
         generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
 
-        output = model.generate(**encoded_input, generation_config=generation_config)
+        stopping_criteria = StoppingCriteriaList(
+            [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
+        )
+
+        output = model.generate(**encoded_input, generation_config=generation_config, stopping_criteria=stopping_criteria)
         chunk = tokenizer.decode(output[0], skip_special_tokens=True)
         generated_text += chunk
         yield chunk
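One caveat on the lambda moved here: transformers calls each entry of a StoppingCriteriaList as criterion(input_ids, scores), where scores are logits, so a criterion that compares decoded text normally inspects input_ids instead. A hedged sketch of the conventional subclass form, assuming the app's tokenizer and validated_body are in scope; StopOnSequences is a hypothetical name, not part of this commit:

    # Sketch of a decode-based stopping criterion as a StoppingCriteria
    # subclass. StopOnSequences is a hypothetical helper; tokenizer and
    # validated_body are assumed from the surrounding app code.
    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnSequences(StoppingCriteria):
        def __init__(self, tokenizer, stop_sequences):
            self.tokenizer = tokenizer
            self.stop_sequences = stop_sequences

        def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
            # Decode what has been generated so far and stop when any
            # stop sequence appears in the text.
            text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            return any(s in text for s in self.stop_sequences)

    stopping_criteria = StoppingCriteriaList(
        [StopOnSequences(tokenizer, validated_body.stop_sequences)]
        if validated_body.stop_sequences else []
    )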