Hjgugugjhuhjggg committed
Commit c17efbf · verified · 1 Parent(s): 6e229a7

Update app.py

Files changed (1):
  app.py +24 -38
app.py CHANGED
@@ -6,8 +6,7 @@ from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
     pipeline,
-    AutoModelForSeq2SeqLM,
-    AutoModelForCausalLM,
+    AutoModel,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList
@@ -83,18 +82,10 @@ class S3ModelLoader:
             s3_uri, local_files_only=True
         )
 
-        if "llama" in model_name:
-            model = AutoModelForCausalLM.from_pretrained(
-                s3_uri, config=config, local_files_only=True, rope_scaling = {"type": "linear", "factor": 2.0}
-            )
-        elif 'qwen' in model_name:
-            model = AutoModelForCausalLM.from_pretrained(
-                s3_uri, config=config, local_files_only=True
-            )
-        else:
-            model = AutoModelForSeq2SeqLM.from_pretrained(
-                s3_uri, config=config, local_files_only=True
-            )
+        model = AutoModel.from_pretrained(
+            s3_uri, config=config, local_files_only=True
+        )
+
 
         tokenizer = AutoTokenizer.from_pretrained(
             s3_uri, config=config, local_files_only=True
@@ -115,19 +106,11 @@ class S3ModelLoader:
             model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
         )
 
-        if "llama" in model_name:
-            model = AutoModelForCausalLM.from_pretrained(
-                model_name, config=config, token=HUGGINGFACE_HUB_TOKEN, rope_scaling = {"type": "linear", "factor": 2.0}
-            )
-        elif 'qwen' in model_name:
-            model = AutoModelForCausalLM.from_pretrained(
-                model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
-            )
-        else:
-            model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
-            )
+        model = AutoModel.from_pretrained(
+            model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
+        )
 
+
         if tokenizer.eos_token_id is not None and \
                 tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = config.pad_token_id \
@@ -164,23 +147,26 @@ async def generate(request: GenerateRequest):
         load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
+
+        if "text-to-text" == task_type:
+            generation_config = GenerationConfig(
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                do_sample=do_sample,
+                num_return_sequences=num_return_sequences,
+            )
 
-        generation_config = GenerationConfig(
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            do_sample=do_sample,
-            num_return_sequences=num_return_sequences,
-        )
-
-        return StreamingResponse(
-            stream_text(model, tokenizer, input_text,
-                        generation_config, stop_sequences,
-                        device, chunk_delay),
-            media_type="text/plain"
-        )
+            return StreamingResponse(
+                stream_text(model, tokenizer, input_text,
+                            generation_config, stop_sequences,
+                            device, chunk_delay),
+                media_type="text/plain"
+            )
+        else:
+            return HTTPException(status_code=400, detail="Task type not text-to-text")
 
     except Exception as e:
         raise HTTPException(
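One caveat the consolidated loader may run into: AutoModel.from_pretrained returns the bare backbone without a language-modeling head, so the generate/stream_text path generally cannot decode text from it. Below is a minimal head-aware sketch, not part of this commit, assuming the checkpoint's config declares its architecture; the helper name load_generation_model is hypothetical.

    # Hypothetical sketch: pick the generation head from the config's
    # declared architecture instead of matching on the model name.
    from transformers import (
        AutoConfig,
        AutoModelForCausalLM,
        AutoModelForSeq2SeqLM,
    )

    def load_generation_model(model_id, **kwargs):
        config = AutoConfig.from_pretrained(model_id, **kwargs)
        # config.architectures lists classes such as "LlamaForCausalLM"
        # or "T5ForConditionalGeneration".
        arch = (config.architectures or [""])[0]
        if arch.endswith("ForConditionalGeneration"):
            # Encoder-decoder checkpoints (T5, BART, ...) need the seq2seq head.
            return AutoModelForSeq2SeqLM.from_pretrained(model_id, config=config, **kwargs)
        # Decoder-only checkpoints (Llama, Qwen, ...) need the causal-LM head.
        return AutoModelForCausalLM.from_pretrained(model_id, config=config, **kwargs)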
 
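A second hedged note, on the new else branch: in FastAPI, an HTTPException only becomes an error response when it is raised; returning it hands the exception object to the response serializer rather than producing a 400. A small sketch of the raising form (route path and detail message follow the diff; the handler body is illustrative):

    # Illustrative sketch: raise HTTPException rather than returning it.
    from fastapi import FastAPI, HTTPException

    app = FastAPI()

    @app.post("/generate")
    async def generate(task_type: str):
        if task_type != "text-to-text":
            # raise (not return) so FastAPI emits a real 400 response
            raise HTTPException(status_code=400, detail="Task type not text-to-text")
        return {"status": "ok"}  # placeholder for the streaming path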