sugiv committed
Commit e1b0723 · Parent: 5b484f5

Leetmonkey In Action via Inference

Files changed (1): app.py (+20, -9)
app.py CHANGED
@@ -11,6 +11,7 @@ from typing import Generator
 from fastapi import FastAPI, HTTPException, Depends
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from pydantic import BaseModel
+import spaces
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -37,6 +38,7 @@ generation_kwargs = {
     "repeat_penalty": 1.1
 }
 
+@spaces.GPU
 def download_model(model_name: str) -> str:
     logger.info(f"Downloading model: {model_name}")
     model_path = hf_hub_download(
@@ -51,15 +53,21 @@ def download_model(model_name: str) -> str:
 
 # Download and load the 8-bit model at startup
 model_path = download_model(MODEL_NAME)
-llm = Llama(
-    model_path=model_path,
-    n_ctx=2048,
-    n_threads=4,
-    n_gpu_layers=-1,  # Use all available GPU layers
-    verbose=False
-)
+
+@spaces.GPU
+def load_model(model_path):
+    return Llama(
+        model_path=model_path,
+        n_ctx=2048,
+        n_threads=4,
+        n_gpu_layers=-1,  # Use all available GPU layers
+        verbose=False
+    )
+
+llm = load_model(model_path)
 logger.info("8-bit model loaded successfully")
 
+@spaces.GPU
 def generate_solution(instruction: str) -> str:
     system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
     full_prompt = f"""### Instruction:
@@ -127,6 +135,7 @@ def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
     raise HTTPException(status_code=401, detail="Invalid token")
 
 @app.post("/generate_solution")
+@spaces.GPU
 async def generate_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
     logger.info("Generating solution")
     generated_output = generate_solution(request.instruction)
@@ -135,6 +144,7 @@ async def generate_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
     return {"solution": formatted_code}
 
 @app.post("/stream_solution")
+@spaces.GPU
 async def stream_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
     async def generate():
         logger.info("Streaming solution")
@@ -154,7 +164,7 @@ Here's the complete Python function implementation:
 
         generated_text = ""
         for chunk in llm(full_prompt, stream=True, **generation_kwargs):
-            token = chunk["choices"][0]["text"]
+            token = chunk["choices"][0]["text"]  # choices is a list in llama-cpp-python; index 0 carries the streamed text
             generated_text += token
             yield token
 
@@ -166,6 +176,7 @@ Here's the complete Python function implementation:
 
 # Gradio wrapper for FastAPI
 def gradio_wrapper(app):
+    @spaces.GPU
     def inference(instruction, token):
         import requests
         url = "http://localhost:8000/generate_solution"
@@ -197,4 +208,4 @@ if __name__ == "__main__":
 
     # Launch Gradio interface
     iface = gradio_wrapper(app)
-    iface.launch(share=True)
+    iface.launch(share=True)
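The substance of this commit is the @spaces.GPU decorator from the Hugging Face spaces package (ZeroGPU): the Space holds no GPU at rest, and a decorated function is granted one for the duration of each call. A minimal sketch of the pattern with a llama-cpp-python model; the GGUF path and prompt are placeholders, not taken from this repo:

import spaces
from llama_cpp import Llama

llm = None  # lazy init keeps import-time code CPU-only

@spaces.GPU  # a GPU is attached only while this function runs
def run_inference(prompt: str) -> str:
    global llm
    if llm is None:
        # Placeholder path; app.py fetches its GGUF via hf_hub_download.
        llm = Llama(model_path="model.Q8_0.gguf", n_gpu_layers=-1, verbose=False)
    out = llm(prompt, max_tokens=256)
    return out["choices"][0]["text"]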
 
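For reference, the /generate_solution endpoint decorated above is the same one the Gradio inference helper posts to. A minimal client sketch, assuming the app serves on localhost:8000 (the url hard-coded in the helper) and that <your-token> stands in for whatever verify_token accepts:

import requests

API_URL = "http://localhost:8000/generate_solution"  # host/port as in the Gradio helper
API_TOKEN = "<your-token>"  # placeholder; must satisfy verify_token

def request_solution(instruction: str) -> str:
    # ProblemRequest is the Pydantic body model; the endpoint returns {"solution": ...}
    resp = requests.post(
        API_URL,
        json={"instruction": instruction},
        headers={"Authorization": f"Bearer {API_TOKEN}"},
        timeout=120,
    )
    resp.raise_for_status()
    return resp.json()["solution"]

print(request_solution("Implement def twoSum(nums, target) returning the indices of the two numbers that sum to target."))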
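Similarly, /stream_solution yields tokens from an inner async generator; assuming that generator is wrapped in a StreamingResponse (the wrapping sits outside the hunks shown here), a client can print tokens as they arrive. A sketch with the same placeholder host and token:

import requests

with requests.post(
    "http://localhost:8000/stream_solution",
    json={"instruction": "Implement def twoSum(nums, target)."},
    headers={"Authorization": "Bearer <your-token>"},
    stream=True,  # keep the connection open and read the body incrementally
    timeout=300,
) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)  # each chunk is one or more streamed tokens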