Empereur-Pirate committed
Commit c2fe3af · verified · 1 Parent(s): c5c4414

Update main.py

Files changed (1)
  1. main.py +17 -11
main.py CHANGED
@@ -4,9 +4,23 @@ from fastapi.staticfiles import StaticFiles
 from transformers import pipeline
 from pydantic import BaseModel
 from typing import Union
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 app = FastAPI()
 
+# Load the MIKU model and tokenizer
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = AutoModelForCausalLM.from_pretrained("miqudev/miqu-1-70b").to(device)
+tokenizer = AutoTokenizer.from_pretrained("miqudev/miqu-1-70b")
+
+def miuk_answer(query: str):
+    query_tokens = tokenizer.encode(query, return_tensors="pt")
+    query_tokens = query_tokens.to(device)
+    answer = model.generate(query_tokens, max_length=128, temperature=1, pad_token_id=tokenizer.pad_token_id)
+    return tokenizer.decode(answer[:, 0]).replace(" ", "")
+
+
 # Serve the static files
 app.mount("/static", StaticFiles(directory="static"), name="static")
 
@@ -41,16 +55,6 @@ def get_default_inference_endpoint():
 def index():
     return './static/index.html'
 
-from typing import Union
-from transformers import pipeline
-
-# Load miku pipeline
-pipe_miku = pipeline("text-generation", model="miqudev/miqu-1-70b")
-
-def miku(input):
-    output = pipe_miku(input)
-    return {"output": output[0]["generated_text"]}
-
 @app.post("/infer_miku")
 def infer_endpoint(data: dict):
     """Receive input and generate text."""
@@ -62,10 +66,12 @@ def infer_endpoint(data: dict):
     if input_text is None:
         return {"error": "No input text detected."}
     else:
-        result = miku(input_text)
+        result = {"output": miuk_answer(input_text)}
         return result
 
 @app.get("/infer_miku")
 def get_default_inference_endpoint():
     return {"message": "Use POST method to submit input data"}
+
+
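For reference, a quick way to exercise the endpoint once the app is running. The exact JSON field the handler reads is outside the hunks shown, so the "input" key below is an assumption:

import requests

# "input" is a guessed field name; the line that extracts it from the
# request body is not part of the hunks shown above.
resp = requests.post(
    "http://localhost:8000/infer_miku",
    json={"input": "Hello, Miku!"},
)
print(resp.json())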