Empereur-Pirate committed · verified
Commit 44a140e · 1 Parent(s): 0551907

Update main.py

Files changed (1)
  1. main.py +14 -14
main.py CHANGED
@@ -25,7 +25,7 @@ def t5(input: str) -> dict[str, str]:
 class ParseRaw(BaseModel):
     raw: bytes
 
-@app.post("/infer_t5")
+@app .post("/infer_t5")
 async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
     """Receive input and generate text."""
     try:
@@ -43,28 +43,28 @@ async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
     except AssertionError as e:
         return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)
 
-@app.get("/infer_t5")
+@app .get("/infer_t5")
 def get_default_inference_endpoint():
     return {"message": "Use POST method to submit input data"}
 
-# Load the MIKU model and tokenizer
+# Initialize device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 try:
-    # Attempt to load the model and tokenizer regularly
-    model_config = AutoConfig.from_pretrained("miqudev/miqu-1-70b")
-    model = AutoModelForCausalLM.from_pretrained("miqudev/miqu-1-70b", config=model_config).to(device)
-    tokenizer = AutoTokenizer.from_pretrained("miqudev/miqu-1-70b")
+    # Initiate the model and tokenizer with the corrected pre-trained weights
+    model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
+    tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")
 except Exception as e:
     print("[WARNING]: Failed to load model and tokenizer conventionally.")
     print(f"Exception: {e}")
 
-    # Construct a dummy configuration object
-    model_config = AutoConfig.from_pretrained("miqudev/miqu-1-70b", trust_remote_code=True)
+    # Configure a fallback mechanism similar to the original implementation
+    model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)
 
-    # Load the model using the constructed configuration
-    model = AutoModelForCausalLM.from_pretrained("miqudev/miqu-1-70b", config=model_config).to(device)
-    tokenizer = AutoTokenizer.from_pretrained("miqudev/miqu-1-70b")
+    # Load the model using the fallback configuration
+    model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
+    tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")
 
 def miuk_answer(query: str) -> str:
     query_tokens = tokenizer.encode(query, return_tensors="pt")
@@ -72,7 +72,7 @@ def miuk_answer(query: str) -> str:
     answer = model.generate(query_tokens, max_length=128, temperature=1, pad_token_id=tokenizer.pad_token_id)
     return tokenizer.decode(answer[:, 0]).replace(" ", "")
 
-@app.post("/infer_miku")
+@app .post("/infer_miku")
 async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
     """Receive input and generate text."""
     try:
@@ -90,6 +90,6 @@ async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
     except AssertionError as e:
         return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)
 
-@app.get("/infer_miku")
+@app .get("/infer_miku")
 def get_default_inference_endpoint():
     return {"message": "Use POST method to submit input data"}
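Note: the committed decorator lines read @app .post(...) and @app .get(...) with a stray space before the dot. Python's grammar tolerates whitespace around attribute access, so the routes still register, but the space looks accidental.

Both POST endpoints receive their input through Depends(parse_raw), and parse_raw itself is outside this diff. A minimal sketch of what such a dependency could look like, assuming it simply wraps the raw request body in the ParseRaw model (the Request-based implementation below is an assumption, not code from this commit):

from fastapi import Depends, FastAPI, Request
from pydantic import BaseModel

app = FastAPI()

class ParseRaw(BaseModel):
    raw: bytes

# Assumed implementation: the real parse_raw in main.py is not shown in this diff.
async def parse_raw(request: Request) -> ParseRaw:
    # Read the raw request body and wrap it in the ParseRaw model.
    body = await request.body()
    return ParseRaw(raw=body)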
 
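A side note on miuk_answer as committed: query_tokens is never moved to the model's device, so generation will fail when the model sits on CUDA, and tokenizer.decode(answer[:, 0]) decodes the first token position of each sequence rather than the first generated sequence. A corrected sketch, not part of this commit:

def miuk_answer(query: str) -> str:
    # Encode the prompt and move it to the same device as the model.
    query_tokens = tokenizer.encode(query, return_tensors="pt").to(device)
    answer = model.generate(
        query_tokens,
        max_length=128,
        temperature=1,
        pad_token_id=tokenizer.pad_token_id,
    )
    # answer has shape (batch, seq_len); decode the first full sequence.
    return tokenizer.decode(answer[0], skip_special_tokens=True)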
 
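Once the app is served (for example with uvicorn main:app), the POST routes can be exercised from Python. A sketch, assuming a local server on port 8000 and that the endpoint consumes the raw request body:

import requests

# Hypothetical host and port; adjust to the actual deployment.
resp = requests.post("http://localhost:8000/infer_miku", data=b"Hello, who are you?")
print(resp.status_code, resp.json())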