Hjgugugjhuhjggg committed on
Commit
09e6b0b
·
verified ·
1 Parent(s): dc7b165

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import json
3
  import logging
4
  import boto3
5
- from fastapi import FastAPI, HTTPException
6
  from fastapi.responses import JSONResponse
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from huggingface_hub import hf_hub_download
@@ -183,7 +183,7 @@ def continue_generation(input_text, model, tokenizer, max_tokens=MAX_TOKENS):
183
  return generated_text
184
 
185
  @app.post("/generate")
186
- async def generate_text(model_name: str, input_text: str):
187
  try:
188
  model_loader = S3DirectStream(S3_BUCKET_NAME)
189
  model = await model_loader.load_model_from_s3(model_name)
@@ -196,7 +196,7 @@ async def generate_text(model_name: str, input_text: str):
196
  return {"generated_text": generated_text}
197
 
198
  except Exception as e:
199
- return JSONResponse(status_code=500, content={"detail": str(e)})
200
 
201
  if __name__ == "__main__":
202
  import uvicorn
 
2
  import json
3
  import logging
4
  import boto3
5
+ from fastapi import FastAPI, HTTPException, Query
6
  from fastapi.responses import JSONResponse
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from huggingface_hub import hf_hub_download
 
183
  return generated_text
184
 
185
  @app.post("/generate")
186
+ async def generate_text(model_name: str = Query(...), input_text: str = Query(...)):
187
  try:
188
  model_loader = S3DirectStream(S3_BUCKET_NAME)
189
  model = await model_loader.load_model_from_s3(model_name)
 
196
  return {"generated_text": generated_text}
197
 
198
  except Exception as e:
199
+ raise HTTPException(status_code=500, detail=str(e))
200
 
201
  if __name__ == "__main__":
202
  import uvicorn