Leetmonkey In Action via Inference
app.py CHANGED
@@ -11,6 +11,7 @@ from typing import Generator
 from fastapi import FastAPI, HTTPException, Depends
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from pydantic import BaseModel
+import spaces
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
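
The only addition in this hunk is `import spaces`, the Hugging Face package behind ZeroGPU. Functions decorated with `@spaces.GPU` (used throughout the rest of the diff) are scheduled onto a GPU only for the duration of the call. A minimal sketch of that contract; `probe_gpu` is a hypothetical example, not part of this app:

import spaces
import torch  # illustration only; app.py itself uses llama-cpp-python

@spaces.GPU  # ZeroGPU attaches a GPU while this function runs
def probe_gpu() -> str:
    # Inside the decorated call, CUDA should be visible.
    if torch.cuda.is_available():
        return torch.cuda.get_device_name(0)
    return "no GPU allocated"
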
@@ -37,6 +38,7 @@ generation_kwargs = {
     "repeat_penalty": 1.1
 }
 
+@spaces.GPU
 def download_model(model_name: str) -> str:
     logger.info(f"Downloading model: {model_name}")
     model_path = hf_hub_download(
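
The hunk cuts off inside `download_model`; only the start of the `hf_hub_download` call is visible. A hedged sketch of how such a function is typically completed; the `repo_id` and `filename` arguments are assumptions, since the actual values fall outside the hunk:

from huggingface_hub import hf_hub_download

def download_model(model_name: str) -> str:
    logger.info(f"Downloading model: {model_name}")
    # Fetch a quantized GGUF file from the Hub and return its local path.
    # The filename is hypothetical; the real repo layout may differ.
    return hf_hub_download(
        repo_id=model_name,
        filename="model.Q8_0.gguf",
    )
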
@@ -51,15 +53,21 @@ def download_model(model_name: str) -> str:
 
 # Download and load the 8-bit model at startup
 model_path = download_model(MODEL_NAME)
-llm = Llama(
-    model_path=model_path,
-    n_ctx=2048,
-    n_threads=4,
-    n_gpu_layers=-1,
-    verbose=False
-)
+
+@spaces.GPU
+def load_model(model_path):
+    return Llama(
+        model_path=model_path,
+        n_ctx=2048,
+        n_threads=4,
+        n_gpu_layers=-1,  # Use all available GPU layers
+        verbose=False
+    )
+
+llm = load_model(model_path)
 logger.info("8-bit model loaded successfully")
 
+@spaces.GPU
 def generate_solution(instruction: str) -> str:
     system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
     full_prompt = f"""### Instruction:
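
`generate_solution` builds an Alpaca-style prompt: the diff shows the `### Instruction:` opening, and a later hunk header reveals a `Here's the complete Python function implementation:` line inside the same template. A sketch of the likely overall shape; the middle portion is an assumed reconstruction:

def build_prompt(system_prompt: str, instruction: str) -> str:
    # Alpaca-style template. The opening and the final line are visible
    # in the diff; everything in between is an assumption.
    return (
        "### Instruction:\n"
        f"{system_prompt}\n\n"
        f"{instruction}\n\n"
        "### Response:\n"
        "Here's the complete Python function implementation:\n"
    )
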
@@ -127,6 +135,7 @@ def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
         raise HTTPException(status_code=401, detail="Invalid token")
 
 @app.post("/generate_solution")
+@spaces.GPU
 async def generate_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
     logger.info("Generating solution")
     generated_output = generate_solution(request.instruction)
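
The context line raising the 401 comes from `verify_token`, the `HTTPBearer` dependency that guards both endpoints; its full body falls outside the diff. A minimal sketch of the usual pattern, with the token's storage location an assumption:

import os
from fastapi import Depends, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

security = HTTPBearer()
API_TOKEN = os.environ.get("API_TOKEN")  # assumed source of the secret

def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> bool:
    # Reject any request whose bearer token does not match the secret.
    if credentials.credentials != API_TOKEN:
        raise HTTPException(status_code=401, detail="Invalid token")
    return True
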
@@ -135,6 +144,7 @@ async def generate_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
     return {"solution": formatted_code}
 
 @app.post("/stream_solution")
+@spaces.GPU
 async def stream_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
     async def generate():
         logger.info("Streaming solution")
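
`stream_solution_api` wraps its work in an inner `generate()` generator; the part of the handler that returns it falls outside the hunk. In FastAPI the conventional wiring, assumed here, is a `StreamingResponse`:

from fastapi.responses import StreamingResponse

@app.post("/stream_solution")
async def stream_solution_api(request: ProblemRequest,
                              authorized: bool = Depends(verify_token)):
    async def generate():
        logger.info("Streaming solution")
        # ... yield tokens as the model produces them ...
        yield ""

    # Assumed wiring: stream the generator back as plain text.
    return StreamingResponse(generate(), media_type="text/plain")
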
@@ -154,7 +164,7 @@ Here's the complete Python function implementation:
 
         generated_text = ""
         for chunk in llm(full_prompt, stream=True, **generation_kwargs):
-            token = chunk["choices"][0]["text"]
+            token = chunk["choices"]["text"]
             generated_text += token
             yield token
 
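
One caution on the changed line: in llama-cpp-python's streaming API, each chunk's "choices" field is a list, so the token text is normally read as chunk["choices"][0]["text"]; indexing it directly with "text" would raise a TypeError at runtime. A sketch of the usual loop:

# Typical llama-cpp-python streaming loop: "choices" is a list,
# so the incremental text lives at choices[0]["text"].
generated_text = ""
for chunk in llm(full_prompt, stream=True, **generation_kwargs):
    token = chunk["choices"][0]["text"]
    generated_text += token
    print(token, end="", flush=True)
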
@@ -166,6 +176,7 @@ Here's the complete Python function implementation:
 
 # Gradio wrapper for FastAPI
 def gradio_wrapper(app):
+    @spaces.GPU
     def inference(instruction, token):
         import requests
         url = "http://localhost:8000/generate_solution"
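
`inference` is the Gradio callback; it forwards the textbox contents to the FastAPI endpoint over HTTP. Only the URL is visible in the hunk, so the request shape below is an assumption, though the {"solution": ...} response key matches the earlier /generate_solution hunk:

import requests

def inference(instruction: str, token: str) -> str:
    url = "http://localhost:8000/generate_solution"
    headers = {"Authorization": f"Bearer {token}"}  # HTTPBearer scheme
    payload = {"instruction": instruction}          # ProblemRequest body
    resp = requests.post(url, json=payload, headers=headers, timeout=300)
    resp.raise_for_status()
    return resp.json()["solution"]

With ZeroGPU, decorating this callback with @spaces.GPU presumably keeps a GPU attached for the duration of the request it triggers, since the FastAPI app runs in the same process.
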
@@ -197,4 +208,4 @@ if __name__ == "__main__":
 
     # Launch Gradio interface
     iface = gradio_wrapper(app)
-    iface.launch(share=True)
+    iface.launch(share=True)
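
The `if __name__ == "__main__":` block launches the Gradio interface; since `inference` posts to http://localhost:8000, the FastAPI app presumably also runs in-process, commonly via uvicorn on a background thread. A sketch under that assumption:

import threading
import uvicorn

if __name__ == "__main__":
    # Assumed wiring: serve FastAPI on :8000 in the background so the
    # Gradio callback can reach it, then launch the Gradio UI.
    threading.Thread(
        target=uvicorn.run,
        kwargs={"app": app, "host": "0.0.0.0", "port": 8000},
        daemon=True,
    ).start()

    iface = gradio_wrapper(app)
    iface.launch(share=True)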