sugiv committed on
Commit 3ec3dc0
1 Parent(s): 1238dd2

Leetmonkey In Action via Inference

Files changed (1)
  1. app.py +14 -65
app.py CHANGED
@@ -3,17 +3,14 @@ import re
 import logging
 import textwrap
 import autopep8
-import gradio as gr
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 import jwt
-from typing import Generator
+from typing import AsyncGenerator
 from fastapi import FastAPI, HTTPException, Depends
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from pydantic import BaseModel
-import spaces
-import torch
-from threading import Thread
+from fastapi.responses import StreamingResponse
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -54,21 +51,15 @@ def download_model(model_name: str) -> str:
 
 # Download and load the 8-bit model at startup
 model_path = download_model(MODEL_NAME)
-
-@spaces.GPU
-def load_model(model_path):
-    return Llama(
-        model_path=model_path,
-        n_ctx=2048,
-        n_threads=4,
-        n_gpu_layers=-1,  # Use all available GPU layers
-        verbose=False
-    )
-
-llm = load_model(model_path)
+llm = Llama(
+    model_path=model_path,
+    n_ctx=2048,
+    n_threads=4,
+    n_gpu_layers=-1,  # Use all available GPU layers
+    verbose=False
+)
 logger.info("8-bit model loaded successfully")
 
-@spaces.GPU
 def generate_solution(instruction: str) -> str:
     system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
     full_prompt = f"""### Instruction:
@@ -145,7 +136,7 @@ async def generate_solution_api(request: ProblemRequest, authorized: bool = Depe
 
 @app.post("/stream_solution")
 async def stream_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
-    async def generate():
+    async def generate() -> AsyncGenerator[str, None]:
         logger.info("Streaming solution")
         system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
         full_prompt = f"""### Instruction:
@@ -163,7 +154,7 @@ Here's the complete Python function implementation:
 
         generated_text = ""
         for chunk in llm(full_prompt, stream=True, **generation_kwargs):
-            token = chunk["choices"]["text"]
+            token = chunk["choices"][0]["text"]
            generated_text += token
            yield token
 
@@ -171,50 +162,8 @@ Here's the complete Python function implementation:
         logger.info("Solution generated successfully")
         yield formatted_code
 
-    return generate()
-
-# Gradio wrapper for FastAPI
-def gradio_wrapper(app):
-    def inference(instruction, token):
-        import requests
-        url = "http://localhost:8000/generate_solution"
-        headers = {"Authorization": f"Bearer {token}"}
-        response = requests.post(url, json={"instruction": instruction}, headers=headers)
-        if response.status_code == 200:
-            return response.json()["solution"]
-        else:
-            return f"Error: {response.status_code}, {response.text}"
-
-    iface = gr.Interface(
-        fn=inference,
-        inputs=[
-            gr.Textbox(label="LeetCode Problem Instruction"),
-            gr.Textbox(label="JWT Token")
-        ],
-        outputs=gr.Code(label="Generated Solution"),
-        title="LeetCode Problem Solver API",
-        description="Enter a LeetCode problem instruction and your JWT token to generate a solution."
-    )
-    return iface
-
-@spaces.GPU
-def main():
-    # Verify GPU availability
-    zero = torch.Tensor().cuda()
-    print(f"GPU availability: {zero.device}")
-
-    # Download and load the model
-    model_path = download_model(MODEL_NAME)
-    global llm
-    llm = load_model(model_path)
-    logger.info("8-bit model loaded successfully")
-
-    # Start FastAPI in a separate thread
-    Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=8000)).start()
-
-    # Launch Gradio interface
-    iface = gradio_wrapper(app)
-    iface.launch(share=True)
+    return StreamingResponse(generate(), media_type="text/plain")
 
 if __name__ == "__main__":
-    main()
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
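With the Gradio/ZeroGPU wrapper removed, the Space serves the two FastAPI endpoints directly over uvicorn. A minimal client sketch, not part of the commit: it assumes the server is reachable at localhost:8000 and that token holds a JWT accepted by verify_token, and it mirrors the request/response shape used by the deleted gradio_wrapper (json={"instruction": ...} in, "solution" key out).

import requests

BASE_URL = "http://localhost:8000"   # host/port used by uvicorn.run(...) above (assumption for local runs)
token = "YOUR_JWT"                   # assumption: a token that verify_token accepts
headers = {"Authorization": f"Bearer {token}"}
problem = "Implement def two_sum(nums, target) returning indices of the two numbers that sum to target."

# Blocking endpoint: one JSON response containing the whole solution.
resp = requests.post(f"{BASE_URL}/generate_solution", json={"instruction": problem}, headers=headers)
resp.raise_for_status()
print(resp.json()["solution"])

# Streaming endpoint: plain-text chunks from the StreamingResponse.
with requests.post(f"{BASE_URL}/stream_solution", json={"instruction": problem}, headers=headers, stream=True) as r:
    r.raise_for_status()
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)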
 
 
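For reference on the one-line streaming fix: a llama_cpp.Llama instance called with stream=True yields completion chunks whose "choices" value is a list of dicts, so the old chunk["choices"]["text"] raised TypeError (list indices must be integers). A standalone sketch of the corrected loop; "model.gguf" is a placeholder for the file that download_model(MODEL_NAME) fetches in app.py.

from llama_cpp import Llama

# Placeholder path; in app.py the model file comes from download_model(MODEL_NAME).
llm = Llama(model_path="model.gguf", n_ctx=2048, n_threads=4, verbose=False)

generated_text = ""
for chunk in llm("### Instruction:\nReverse a string in Python.\n\n### Response:", stream=True, max_tokens=64):
    token = chunk["choices"][0]["text"]   # each chunk: {"choices": [{"text": "...", ...}], ...}
    generated_text += token
print(generated_text)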