sugiv committed
Commit 9c3d676 · Parent(s): 5b97345

Leetmonkey In Action via Inference

Files changed (1)
app.py +209 -7
app.py CHANGED
@@ -1,14 +1,216 @@
import gradio as gr
import spaces
import torch

- zero = torch.Tensor([0]).cuda()
- print(zero.device)  # <-- 'cpu' 🤔
 
@spaces.GPU
- def greet(n):
-     print(zero.device)  # <-- 'cuda:0' 🤗
-     return f"Hello {zero + n} Tensor"

- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
- demo.launch()
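
The removed lines are the stock ZeroGPU template. On Spaces ZeroGPU hardware a CUDA device is only attached while a @spaces.GPU-decorated function is running, which is why the module-level print reported 'cpu' while the identical print inside greet reported 'cuda:0'. A minimal sketch of the same pattern:

import spaces
import torch

x = torch.tensor([0.0])  # created at import time, so x.device is 'cpu'

@spaces.GPU
def on_gpu():
    # inside a decorated call the GPU is attached and .cuda() succeeds
    return x.cuda().device  # expected: cuda:0 on ZeroGPU hardware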
 
 
+ import os
+ import re
+ import logging
+ import textwrap
+ import autopep8
import gradio as gr
+ from huggingface_hub import hf_hub_download
+ from llama_cpp import Llama
+ import jwt
+ from fastapi import FastAPI, HTTPException, Depends
+ from fastapi.responses import StreamingResponse
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+ from pydantic import BaseModel
import spaces
import torch

+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ # JWT settings
+ JWT_SECRET = os.environ.get("JWT_SECRET")
+ if not JWT_SECRET:
+     raise ValueError("JWT_SECRET environment variable is not set")
+ JWT_ALGORITHM = "HS256"
+
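
The API is protected by an HS256-signed bearer token checked against JWT_SECRET. A sketch of how a client could mint a test token with PyJWT, assuming it holds the same secret; the payload fields ("sub", "exp") are illustrative, since verify_token below only validates the signature and expiry:

import datetime
import jwt  # PyJWT

token = jwt.encode(
    {"sub": "test-user",
     "exp": datetime.datetime.utcnow() + datetime.timedelta(hours=1)},
    "same-secret-as-JWT_SECRET",  # placeholder; must match the server's env var
    algorithm="HS256",
)
print(token)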
+ # Model settings
+ MODEL_NAME = "leetmonkey_peft__q8_0.gguf"
+ REPO_ID = "sugiv/leetmonkey-peft-gguf"
+
+ # Generation parameters
+ generation_kwargs = {
+     "max_tokens": 2048,
+     "stop": ["```", "### Instruction:", "### Response:"],
+     "echo": False,
+     "temperature": 0.2,
+     "top_k": 50,
+     "top_p": 0.95,
+     "repeat_penalty": 1.1
+ }
+
+ @spaces.GPU
+ def download_model(model_name: str) -> str:
+     logger.info(f"Downloading model: {model_name}")
+     model_path = hf_hub_download(
+         repo_id=REPO_ID,
+         filename=model_name,
+         cache_dir="./models",
+         force_download=True,
+         resume_download=True
+     )
+     logger.info(f"Model downloaded: {model_path}")
+     return model_path
+
+ # Download and load the 8-bit model at startup
+ model_path = download_model(MODEL_NAME)
+
+ @spaces.GPU
+ def load_model(model_path):
+     return Llama(
+         model_path=model_path,
+         n_ctx=2048,
+         n_threads=4,
+         n_gpu_layers=-1,  # Use all available GPU layers
+         verbose=False
+     )
+
+ llm = load_model(model_path)
+ logger.info("8-bit model loaded successfully")
+
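
With the GGUF model loaded, a one-off completion is a cheap way to confirm the llama-cpp-python binding and the quantized weights actually work before wiring up the endpoints. A hypothetical smoke test (the prompt is arbitrary):

# Hypothetical smoke test, not part of the commit
out = llm(
    "### Instruction:\nReply with the single word OK.\n\n### Response:\n",
    max_tokens=8,
    temperature=0.0,
)
print(out["choices"][0]["text"])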
+ @spaces.GPU
+ def generate_solution(instruction: str) -> str:
+     system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
+     full_prompt = f"""### Instruction:
+ {system_prompt}
+
+ Implement the following function for the LeetCode problem:
+
+ {instruction}
+
+ ### Response:
+ Here's the complete Python function implementation:
+
+ ```python
+ """
+
+     response = llm(full_prompt, **generation_kwargs)
+     return response["choices"][0]["text"]
+
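
generate_solution wraps the problem statement in an Alpaca-style prompt and pre-opens a ```python fence; the matching "```" entry in generation_kwargs["stop"] then halts generation at the closing fence, so the call returns bare code. A sketch of a call with a hypothetical problem:

# Hypothetical usage; the instruction text is illustrative
problem = (
    "def two_sum(nums: list[int], target: int) -> list[int]:\n"
    "    \"\"\"Return indices of the two numbers that add up to target.\"\"\""
)
raw = generate_solution(problem)
print(raw)  # expected: the function body, without the closing ``` fence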
+ def extract_and_format_code(text: str) -> str:
+     # Extract code between triple backticks
+     code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
+     if code_match:
+         code = code_match.group(1)
+     else:
+         code = text
+
+     # Remove any text before the function definition
+     code = re.sub(r'^.*?(?=def\s+\w+\s*\()', '', code, flags=re.DOTALL)
+
+     # Dedent the code to remove any common leading whitespace
+     code = textwrap.dedent(code)
+
+     # Split the code into lines
+     lines = code.split('\n')
+
+     # Find the function definition line
+     func_def_index = next((i for i, line in enumerate(lines) if line.strip().startswith('def ')), 0)
+
+     # Ensure proper indentation
+     indented_lines = [lines[func_def_index]]  # Keep the function definition as is
+     for line in lines[func_def_index + 1:]:
+         if line.strip():  # If the line is not empty
+             indented_lines.append('    ' + line)  # Add 4 spaces of indentation
+         else:
+             indented_lines.append(line)  # Keep empty lines as is
+
+     formatted_code = '\n'.join(indented_lines)
+
+     try:
+         return autopep8.fix_code(formatted_code)
+     except Exception:
+         return formatted_code
+
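
A sketch of what extract_and_format_code does to a typical raw completion: the regex pulls out the fenced block if one survived, the re.sub strips any chatter before the def, and autopep8 normalizes the indentation:

# Hypothetical input; exact whitespace of the output depends on autopep8
raw = "Here's the solution:\n```python\ndef add(a, b):\n    return a + b\n```"
print(extract_and_format_code(raw))
# expected: a PEP 8-formatted add() definition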
+ security = HTTPBearer()
+ app = FastAPI()
+
+ class ProblemRequest(BaseModel):
+     instruction: str
+
+ def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
+     try:
+         jwt.decode(credentials.credentials, JWT_SECRET, algorithms=[JWT_ALGORITHM])
+         return True
+     except jwt.PyJWTError:
+         raise HTTPException(status_code=401, detail="Invalid token")
+
+ @app.post("/generate_solution")
@spaces.GPU
+ async def generate_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
+     logger.info("Generating solution")
+     generated_output = generate_solution(request.instruction)
+     formatted_code = extract_and_format_code(generated_output)
+     logger.info("Solution generated successfully")
+     return {"solution": formatted_code}
+
+ @app.post("/stream_solution")
+ @spaces.GPU
+ async def stream_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
+     async def generate():
+         logger.info("Streaming solution")
+         system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
+         full_prompt = f"""### Instruction:
+ {system_prompt}
+
+ Implement the following function for the LeetCode problem:
+
+ {request.instruction}
+
+ ### Response:
+ Here's the complete Python function implementation:
+
+ ```python
+ """
+
+         generated_text = ""
+         for chunk in llm(full_prompt, stream=True, **generation_kwargs):
+             token = chunk["choices"][0]["text"]
+             generated_text += token
+             yield token
+
+         formatted_code = extract_and_format_code(generated_text)
+         logger.info("Solution generated successfully")
+         yield formatted_code
+
+     # Wrap the generator in a StreamingResponse so FastAPI actually streams tokens
+     return StreamingResponse(generate(), media_type="text/plain")
+
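
With the generator wrapped in a StreamingResponse, a client can consume tokens as they are produced. A sketch using requests, assuming the server is on localhost:8000 and token is a JWT signed with the same secret:

import requests

with requests.post(
    "http://localhost:8000/stream_solution",
    json={"instruction": "def is_palindrome(s: str) -> bool: ..."},
    headers={"Authorization": f"Bearer {token}"},
    stream=True,
) as resp:
    for piece in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(piece, end="", flush=True)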
+ # Gradio wrapper for FastAPI
+ def gradio_wrapper(app):
+     @spaces.GPU
+     def inference(instruction, token):
+         import requests
+         url = "http://localhost:8000/generate_solution"
+         headers = {"Authorization": f"Bearer {token}"}
+         response = requests.post(url, json={"instruction": instruction}, headers=headers)
+         if response.status_code == 200:
+             return response.json()["solution"]
+         else:
+             return f"Error: {response.status_code}, {response.text}"
+
+     iface = gr.Interface(
+         fn=inference,
+         inputs=[
+             gr.Textbox(label="LeetCode Problem Instruction"),
+             gr.Textbox(label="JWT Token")
+         ],
+         outputs=gr.Code(label="Generated Solution"),
+         title="LeetCode Problem Solver API",
+         description="Enter a LeetCode problem instruction and your JWT token to generate a solution."
+     )
+     return iface
+
+ if __name__ == "__main__":
+     import uvicorn
+     from threading import Thread
+
+     # Verify GPU availability
+     zero = torch.Tensor().cuda()
+     print(f"GPU availability: {zero.device}")
+
+     # Start FastAPI in a separate thread
+     Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=8000)).start()

+     # Launch Gradio interface
+     iface = gradio_wrapper(app)
+     iface.launch(share=True)
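
Running uvicorn in a background thread next to Gradio's own server means two ports and two event loops. A commonly used alternative, sketched here rather than what this commit does, is to mount the Gradio UI onto the FastAPI app with gr.mount_gradio_app so a single uvicorn process serves both:

# Sketch of a single-server layout; not part of this commit
import uvicorn
import gradio as gr

ui = gradio_wrapper(app)
combined = gr.mount_gradio_app(app, ui, path="/ui")

if __name__ == "__main__":
    uvicorn.run(combined, host="0.0.0.0", port=8000)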