Empereur-Pirate commited on
Commit
adea8c1
·
verified ·
1 Parent(s): 6f96f84

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -6
main.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, Request
2
  from fastapi.responses import FileResponse, JSONResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from transformers import pipeline
@@ -18,11 +18,14 @@ def t5(input: str) -> dict[str, str]:
18
  output = pipe_flan(input)
19
  return {"output": output[0].get("generated_text", "")}
20
 
 
 
 
21
  @app.post("/infer_t5")
22
- async def infer_endpoint(data: dict = Depends(parse_raw)):
23
  """Receive input and generate text."""
24
  try:
25
- input_text = data.get("input")
26
 
27
  # Validate that the input is a string
28
  assert isinstance(input_text, str), "Input must be a string."
@@ -66,10 +69,10 @@ def miuk_answer(query: str) -> str:
66
  return tokenizer.decode(answer[:, 0]).replace(" ", "")
67
 
68
  @app.post("/infer_miku")
69
- async def infer_endpoint(data: dict = Depends(parse_raw)):
70
  """Receive input and generate text."""
71
  try:
72
- input_text = data.get("input")
73
 
74
  # Validate that the input is a string
75
  assert isinstance(input_text, str), "Input must be a string."
@@ -85,4 +88,7 @@ async def infer_endpoint(data: dict = Depends(parse_raw)):
85
 
86
  @app.get("/infer_miku")
87
  def get_default_inference_endpoint():
88
- return {"message": "Use POST method to submit input data"}
 
 
 
 
1
+ from fastapi import FastAPI, Request, Depends
2
  from fastapi.responses import FileResponse, JSONResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from transformers import pipeline
 
18
  output = pipe_flan(input)
19
  return {"output": output[0].get("generated_text", "")}
20
 
21
+ class ParseRaw(BaseModel):
22
+ raw: bytes
23
+
24
  @app.post("/infer_t5")
25
+ async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
26
  """Receive input and generate text."""
27
  try:
28
+ input_text = data.raw.decode("utf-8")
29
 
30
  # Validate that the input is a string
31
  assert isinstance(input_text, str), "Input must be a string."
 
69
  return tokenizer.decode(answer[:, 0]).replace(" ", "")
70
 
71
  @app.post("/infer_miku")
72
+ async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
73
  """Receive input and generate text."""
74
  try:
75
+ input_text = data.raw.decode("utf-8")
76
 
77
  # Validate that the input is a string
78
  assert isinstance(input_text, str), "Input must be a string."
 
88
 
89
  @app.get("/infer_miku")
90
  def get_default_inference_endpoint():
91
+ return {"message": "Use POST method to submit input data"}
92
+
93
+ def parse_raw(request: Request):
94
+ return ParseRaw(raw=request.body())