Hjgugugjhuhjggg committed on
Commit
c9fd992
·
verified ·
1 Parent(s): f2dfe81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -35,11 +35,12 @@ class GenerateRequest(BaseModel):
35
  stream: bool = True
36
  top_p: float = 1.0
37
  top_k: int = 50
38
- repetition_penalty: float = 1.0
39
  num_return_sequences: int = 1
40
  do_sample: bool = True
41
  chunk_delay: float = 0.0
42
  stop_sequences: list[str] = []
 
43
 
44
  @field_validator("model_name")
45
  def model_name_cannot_be_empty(cls, v):
@@ -112,6 +113,7 @@ async def generate(request: GenerateRequest):
112
  do_sample = request.do_sample
113
  chunk_delay = request.chunk_delay
114
  stop_sequences = request.stop_sequences
 
115
 
116
  model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
117
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -125,6 +127,7 @@ async def generate(request: GenerateRequest):
125
  repetition_penalty=repetition_penalty,
126
  do_sample=do_sample,
127
  num_return_sequences=num_return_sequences,
 
128
  )
129
 
130
  return StreamingResponse(
@@ -164,7 +167,8 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
164
  stopping_criteria=stopping_criteria,
165
  output_scores=True,
166
  return_dict_in_generate=True,
167
- pad_token_id=tokenizer.pad_token_id
 
168
  )
169
  except IndexError as e:
170
  print(f"IndexError during generation: {e}")
 
35
  stream: bool = True
36
  top_p: float = 1.0
37
  top_k: int = 50
38
+ repetition_penalty: float = 1.1 # Increased default to discourage repetition
39
  num_return_sequences: int = 1
40
  do_sample: bool = True
41
  chunk_delay: float = 0.0
42
  stop_sequences: list[str] = []
43
+ no_repeat_ngram_size: int = 2 # Add parameter to prevent repeating ngrams
44
 
45
  @field_validator("model_name")
46
  def model_name_cannot_be_empty(cls, v):
 
113
  do_sample = request.do_sample
114
  chunk_delay = request.chunk_delay
115
  stop_sequences = request.stop_sequences
116
+ no_repeat_ngram_size = request.no_repeat_ngram_size
117
 
118
  model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
119
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
127
  repetition_penalty=repetition_penalty,
128
  do_sample=do_sample,
129
  num_return_sequences=num_return_sequences,
130
+ no_repeat_ngram_size=no_repeat_ngram_size, # Added to generation config
131
  )
132
 
133
  return StreamingResponse(
 
167
  stopping_criteria=stopping_criteria,
168
  output_scores=True,
169
  return_dict_in_generate=True,
170
+ pad_token_id=tokenizer.pad_token_id,
171
+ no_repeat_ngram_size=generation_config.no_repeat_ngram_size, # Passed to model.generate
172
  )
173
  except IndexError as e:
174
  print(f"IndexError during generation: {e}")