Update main.py
main.py CHANGED
@@ -25,7 +25,7 @@ def t5(input: str) -> dict[str, str]:
 class ParseRaw(BaseModel):
     raw: bytes

-@app.post("/infer_t5")
+@app.post("/infer_t5")
 async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
     """Receive input and generate text."""
     try:
@@ -43,28 +43,28 @@ async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
     except AssertionError as e:
         return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)

-@app.get("/infer_t5")
+@app.get("/infer_t5")
 def get_default_inference_endpoint():
     return {"message": "Use POST method to submit input data"}

-#
+# Initialize device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 try:
-    #
-    model_config = AutoConfig.from_pretrained("
-    model = AutoModelForCausalLM.from_pretrained("
-    tokenizer = AutoTokenizer.from_pretrained("
+    # Initiate the model and tokenizer with the corrected pre-trained weights
+    model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
+    tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")
 except Exception as e:
     print("[WARNING]: Failed to load model and tokenizer conventionally.")
     print(f"Exception: {e}")

-    #
-    model_config = AutoConfig.from_pretrained("
+    # Configure a fallback mechanism similar to the original implementation
+    model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)

-    # Load the model using the
-    model = AutoModelForCausalLM.from_pretrained("
-    tokenizer = AutoTokenizer.from_pretrained("
+    # Load the model using the fallback configuration
+    model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
+    tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")

 def miuk_answer(query: str) -> str:
     query_tokens = tokenizer.encode(query, return_tensors="pt")
@@ -72,7 +72,7 @@ def miuk_answer(query: str) -> str:
     answer = model.generate(query_tokens, max_length=128, temperature=1, pad_token_id=tokenizer.pad_token_id)
     return tokenizer.decode(answer[:, 0]).replace(" ", "")

-@app.post("/infer_miku")
+@app.post("/infer_miku")
 async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
     """Receive input and generate text."""
     try:
@@ -90,6 +90,6 @@ async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
     except AssertionError as e:
         return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)

-@app.get("/infer_miku")
+@app.get("/infer_miku")
 def get_default_inference_endpoint():
     return {"message": "Use POST method to submit input data"}
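Assuming the app is served locally (for example by uvicorn on its default port), a client would send the prompt as the raw request body. A hypothetical client call; the URL and prompt are placeholders:

import requests

# Hypothetical local URL; a deployed Space exposes its own host instead.
url = "http://localhost:8000/infer_miku"

# The endpoint consumes the raw request body, so the prompt goes in as bytes.
response = requests.post(url, data="Hello, Miku!".encode("utf-8"))
print(response.json())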
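For comparison with miuk_answer above: model.generate returns a (batch, sequence_length) tensor of token ids, so decoding one complete output is usually done with output[0] rather than the column slice answer[:, 0] used in this commit. A self-contained sketch of that conventional flow, using gpt2 purely as a stand-in checkpoint (the 70B model in the diff needs far more hardware):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

query_tokens = tokenizer.encode("Hello, Miku!", return_tensors="pt")
# generate() returns shape (batch, sequence_length); row 0 is the whole
# first sequence, while [:, 0] would select only the first token id
# of each sequence in the batch.
output = model.generate(query_tokens, max_length=64, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output[0], skip_special_tokens=True))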