Hjgugugjhuhjggg committed
Commit de3c0e2 · verified · 1 Parent(s): c9fd992

Update app.py

Files changed (1):
  app.py +45 -82
app.py CHANGED
@@ -1,18 +1,11 @@
 import os
 import torch
 from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse
 from pydantic import BaseModel, field_validator
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    GenerationConfig,
-    StoppingCriteriaList
-)
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteria, StoppingCriteriaList
 import boto3
 import uvicorn
-import asyncio
 from io import BytesIO
 from transformers import pipeline
 
@@ -26,21 +19,26 @@ s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_a
 
 app = FastAPI()
 
+SPECIAL_TOKENS = {
+    "bos_token": "<|startoftext|>",
+    "eos_token": "<|endoftext|>",
+    "pad_token": "[PAD]",
+    "unk_token": "[UNK]",
+}
+
 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str = ""
     task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 10
-    stream: bool = True
     top_p: float = 1.0
     top_k: int = 50
-    repetition_penalty: float = 1.1  # Increased default to discourage repetition
+    repetition_penalty: float = 1.1
     num_return_sequences: int = 1
     do_sample: bool = True
-    chunk_delay: float = 0.0
     stop_sequences: list[str] = []
-    no_repeat_ngram_size: int = 2  # Add parameter to prevent repeating ngrams
+    no_repeat_ngram_size: int = 2
 
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
@@ -62,11 +60,11 @@ class GenerateRequest(BaseModel):
         return v
 
 class S3ModelLoader:
-    def __init__(self, bucket_name, s3_client):
+    def __init__(self, bucket_name, s3_client):
        self.bucket_name = bucket_name
        self.s3_client = s3_client
 
-    def _get_s3_uri(self, model_name):
+    def _get_s3_uri(self, model_name):
        return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
 
    async def load_model_and_tokenizer(self, model_name):
@@ -75,20 +73,20 @@ class S3ModelLoader:
            config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
            model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
            tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
-
-            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
-
+            tokenizer.add_special_tokens(SPECIAL_TOKENS)
+            model.resize_token_embeddings(len(tokenizer))
+            if tokenizer.pad_token_id is None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
            return model, tokenizer
        except EnvironmentError:
            try:
                config = AutoConfig.from_pretrained(model_name)
                tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
+                tokenizer.add_special_tokens(SPECIAL_TOKENS)
                model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
-
-                if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                    tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
-
+                model.resize_token_embeddings(len(tokenizer))
+                if tokenizer.pad_token_id is None:
+                    tokenizer.pad_token_id = tokenizer.eos_token_id
                model.save_pretrained(s3_uri)
                tokenizer.save_pretrained(s3_uri)
                return model, tokenizer
@@ -105,13 +103,11 @@ async def generate(request: GenerateRequest):
        task_type = request.task_type
        temperature = request.temperature
        max_new_tokens = request.max_new_tokens
-        stream = request.stream
        top_p = request.top_p
        top_k = request.top_k
        repetition_penalty = request.repetition_penalty
        num_return_sequences = request.num_return_sequences
        do_sample = request.do_sample
-        chunk_delay = request.chunk_delay
        stop_sequences = request.stop_sequences
        no_repeat_ngram_size = request.no_repeat_ngram_size
 
@@ -127,74 +123,41 @@ async def generate(request: GenerateRequest):
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
-            no_repeat_ngram_size=no_repeat_ngram_size,  # Added to generation config
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            pad_token_id=tokenizer.pad_token_id
        )
 
-        return StreamingResponse(
-            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
-            media_type="text/plain"
-        )
+        generated_text = generate_text(model, tokenizer, input_text, generation_config, stop_sequences, device)
+        return JSONResponse({"text": generated_text})
 
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
-async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
+def generate_text(model, tokenizer, input_text, generation_config, stop_sequences, device):
    max_model_length = model.config.max_position_embeddings
    encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True).to(device)
 
-    def stop_criteria(input_ids, scores):
-        decoded_output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-        for stop in stop_sequences:
-            if decoded_output.endswith(stop):
-                return True
-        return False
+    stopping_criteria = StoppingCriteriaList()
 
-    stopping_criteria = StoppingCriteriaList([stop_criteria])
+    class CustomStoppingCriteria(StoppingCriteria):
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+            decoded_output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+            for stop in stop_sequences:
+                if decoded_output.endswith(stop):
+                    return True
+            return False
 
-    token_buffer = []
-    output_ids = encoded_input.input_ids
-    while True:
-        try:
-            outputs = model.generate(
-                output_ids,
-                do_sample=generation_config.do_sample,
-                max_new_tokens=generation_config.max_new_tokens,
-                temperature=generation_config.temperature,
-                top_p=generation_config.top_p,
-                top_k=generation_config.top_k,
-                repetition_penalty=generation_config.repetition_penalty,
-                num_return_sequences=generation_config.num_return_sequences,
-                stopping_criteria=stopping_criteria,
-                output_scores=True,
-                return_dict_in_generate=True,
-                pad_token_id=tokenizer.pad_token_id,
-                no_repeat_ngram_size=generation_config.no_repeat_ngram_size,  # Passed to model.generate
-            )
-        except IndexError as e:
-            print(f"IndexError during generation: {e}")
-            break
-
-        new_token_ids = outputs.sequences[0][encoded_input.input_ids.shape[-1]:]
-
-        for token_id in new_token_ids:
-            token = tokenizer.decode(token_id, skip_special_tokens=True)
-            token_buffer.append(token)
-            if len(token_buffer) >= 10:
-                yield "".join(token_buffer)
-                token_buffer = []
-                await asyncio.sleep(chunk_delay)
-
-        if token_buffer:
-            yield "".join(token_buffer)
-            token_buffer = []
-
-        if stop_criteria(outputs.sequences, None):
-            break
-
-        if len(new_token_ids) < generation_config.max_new_tokens:
-            break
-
-        output_ids = outputs.sequences
+    stopping_criteria.append(CustomStoppingCriteria())
+
+    outputs = model.generate(
+        encoded_input.input_ids,
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria,
+        pad_token_id=generation_config.pad_token_id
+    )
+
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return generated_text
 
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
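With streaming removed, /generate now returns the whole completion in one JSON body of the form {"text": ...}. A minimal client sketch against the updated endpoint; the host, port, and model name below are assumptions for illustration, not part of this commit:

# Hypothetical client for the updated /generate endpoint. Assumes the app is
# served locally on uvicorn's default port; "gpt2" is a stand-in model name.
import requests

payload = {
    "model_name": "gpt2",
    "input_text": "Once upon a time",
    "task_type": "text-generation",
    "max_new_tokens": 50,
    "stop_sequences": ["\n\n"],
}
resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
print(resp.json()["text"])  # full generation, no longer a streamed text/plain body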
 
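The loader now registers SPECIAL_TOKENS and then calls model.resize_token_embeddings(len(tokenizer)): tokens added to the tokenizer get ids past the original vocabulary, so the embedding matrix must grow to match or lookups on the new ids fail. A standalone sketch of that recipe; "gpt2" is only a stand-in model:

# add_special_tokens() returns how many tokens were actually new; resizing
# afterwards keeps the embedding matrix in sync with the enlarged vocabulary.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

num_added = tokenizer.add_special_tokens({"pad_token": "[PAD]"})
print(num_added, len(tokenizer))                   # 1 50258
model.resize_token_embeddings(len(tokenizer))
print(model.get_input_embeddings().weight.shape)   # torch.Size([50258, 768])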
 
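For reference, the stop-sequence check in the new generate_text() is an instance of the standard transformers StoppingCriteria pattern: a criterion subclasses StoppingCriteria and is collected into a StoppingCriteriaList passed to model.generate(). A self-contained sketch of the same idea; the model and prompt are stand-ins:

# Stop generation as soon as the decoded text ends with any stop string.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

class StopOnSequences(StoppingCriteria):
    def __init__(self, stop_sequences, tokenizer):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode everything generated so far; True halts generation.
        decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(decoded.endswith(stop) for stop in self.stop_sequences)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(
    inputs.input_ids,
    max_new_tokens=20,
    stopping_criteria=StoppingCriteriaList([StopOnSequences(["."], tokenizer)]),
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))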