|
import os |
|
import re |
|
import logging |
|
import textwrap |
|
import autopep8 |
|
from huggingface_hub import hf_hub_download |
|
from llama_cpp import Llama |
|
import jwt |
|
from typing import AsyncGenerator |
|
from fastapi import FastAPI, HTTPException, Depends |
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
|
from pydantic import BaseModel |
|
from fastapi.responses import StreamingResponse |
|
|
|
|
|
# Module-level logger shared by every function in this file; the root logger
# is configured right below so INFO messages (download/load/generation
# progress) are actually emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
# HMAC algorithm accepted when verifying incoming bearer tokens.
JWT_ALGORITHM = "HS256"

# Shared secret used to verify tokens. Importing this module fails fast when
# it is missing, so the server never starts without authentication configured.
JWT_SECRET = os.getenv("JWT_SECRET")
if not JWT_SECRET:
    # `not` also rejects an empty string, not only an unset variable.
    raise ValueError("JWT_SECRET environment variable is not set")
|
|
|
|
|
# Hugging Face Hub location of the quantized GGUF weights served by this app:
# the repository, then the file inside it.
REPO_ID = "sugiv/leetmonkey-peft-gguf"
MODEL_NAME = "leetmonkey_peft__q8_0.gguf"
|
|
|
|
|
# Sampling settings passed to every `llm(...)` call, blocking and streaming.
generation_kwargs = dict(
    max_tokens=2048,
    # Stop at a code fence (the prompt already opens a ```python fence, so the
    # next one closes it) or if the model starts a new prompt section.
    stop=["```", "### Instruction:", "### Response:"],
    echo=False,  # return only the completion, not the prompt
    temperature=0.2,
    top_k=50,
    top_p=0.95,
    repeat_penalty=1.1,
)
|
|
|
def download_model(model_name: str, *, force_download: bool = False) -> str:
    """Fetch *model_name* from the ``REPO_ID`` Hugging Face repo and return its local path.

    Files are cached under ``./models``. A cached copy is reused unless
    *force_download* is true; ``hf_hub_download`` still consults the Hub,
    so a newer revision of the file is picked up.

    Args:
        model_name: File name inside ``REPO_ID`` (e.g. ``MODEL_NAME``).
        force_download: Re-fetch the file even if it is already cached.
            Defaults to False. Always forcing meant re-downloading the full
            model on every process start.

    Returns:
        Local filesystem path of the downloaded (or cached) file.
    """
    logger.info("Downloading model: %s", model_name)
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=model_name,
        cache_dir="./models",
        force_download=force_download,
        # `resume_download` is intentionally omitted: it is deprecated in
        # huggingface_hub, which resumes interrupted downloads on its own.
    )
    logger.info("Model downloaded: %s", model_path)
    return model_path
|
|
|
|
|
# Fetch the weights and load the model once, at import time. Both the blocking
# and the streaming endpoint call this single `llm` instance.
model_path = download_model(MODEL_NAME)
llm = Llama(
    model_path=model_path,
    n_gpu_layers=-1,  # offload all layers to the GPU when one is available
    n_threads=4,
    n_ctx=2048,  # context window shared by prompt and completion
    verbose=False,
)
logger.info("8-bit model loaded successfully")
|
|
|
def generate_solution(instruction: str) -> str:
    """Run one blocking completion for *instruction* and return the raw model text.

    The prompt ends by opening a ```python fence, so the completion is expected
    to be the code itself; ``generation_kwargs`` lists "```" as a stop
    sequence, so generation halts at the closing fence.
    """
    system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
    prompt_lines = (
        "### Instruction:",
        system_prompt,
        "",
        "Implement the following function for the LeetCode problem:",
        "",
        f"{instruction}",
        "",
        "### Response:",
        "Here's the complete Python function implementation:",
        "",
        "```python",
        "",  # gives the prompt its trailing newline after the opening fence
    )
    completion = llm("\n".join(prompt_lines), **generation_kwargs)
    first_choice = completion["choices"][0]
    return first_choice["text"]
|
|
|
def extract_and_format_code(text: str) -> str:
    """Pull a single Python function out of raw model output and tidy it.

    Steps:
      1. If *text* contains a fenced block opened with ```python, keep only
         its contents; otherwise use *text* as-is.
      2. Drop everything before the first ``def name(`` (prose, or a
         ``class Solution:`` wrapper and its indentation).
      3. Re-indent the function body so its outermost level sits exactly one
         level (four spaces) under the ``def`` line. This works whether the
         model emitted the body flush-left or already indented.
      4. Run autopep8 over the result, falling back to the unformatted code
         if autopep8 raises.

    If no function definition is found, the (dedented) text goes through
    step 4 without any re-indentation.

    Args:
        text: Raw completion text from the model.

    Returns:
        The extracted and formatted source code.
    """
    code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
    code = code_match.group(1) if code_match else text

    # Drop any preamble (explanations, class header) before the first function.
    code = re.sub(r'^.*?(?=def\s+\w+\s*\()', '', code, flags=re.DOTALL)
    code = textwrap.dedent(code)
    lines = code.split('\n')

    func_def_index = next(
        (i for i, line in enumerate(lines) if line.strip().startswith('def ')),
        None,
    )
    if func_def_index is None:
        # Nothing recognisable as a function: re-indenting would only corrupt it.
        formatted_code = code
    else:
        header = lines[func_def_index].lstrip()
        body_lines = lines[func_def_index + 1:]
        if body_lines:
            # Normalise rather than blindly prefix: strip the body's common
            # indentation, then indent it one level under the header. Adding
            # four spaces unconditionally double-indents an already-indented body.
            body = textwrap.indent(textwrap.dedent('\n'.join(body_lines)), '    ')
            formatted_code = header + '\n' + body
        else:
            formatted_code = header

    try:
        return autopep8.fix_code(formatted_code)
    except Exception:
        # Formatting is best-effort: log it, but never fail the request over it.
        logging.getLogger(__name__).warning(
            "autopep8 failed; returning unformatted code", exc_info=True
        )
        return formatted_code
|
|
|
# The FastAPI application, plus the bearer-token extractor that verify_token
# depends on to read the `Authorization: Bearer ...` header.
app = FastAPI()
security = HTTPBearer()


class ProblemRequest(BaseModel):
    # JSON request body accepted by both solution endpoints.

    # Text inserted into the prompt after "Implement the following function
    # for the LeetCode problem:".
    instruction: str
|
|
|
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that rejects requests without a valid signed JWT.

    Args:
        credentials: Bearer credentials extracted by ``security`` (``HTTPBearer``).

    Returns:
        ``True`` when the token decodes and verifies with ``JWT_SECRET`` using
        ``JWT_ALGORITHM`` (PyJWT also checks registered claims such as ``exp``
        when present).

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            PyJWT rejects the token (malformed, bad signature, expired, ...).
    """
    try:
        # Pin the accepted algorithm so a token cannot select its own (e.g. "none").
        jwt.decode(credentials.credentials, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.PyJWTError as exc:
        # HTTP requires a WWW-Authenticate challenge on every 401 response.
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    return True
|
|
|
@app.post("/generate_solution")
async def generate_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
    # Blocking endpoint: responds with {"solution": <formatted code>}.
    # `authorized` is unused; declaring it makes FastAPI run verify_token first.
    logger.info("Generating solution")
    # NOTE(review): generate_solution() runs inference synchronously inside this
    # coroutine, so the event loop is blocked for the whole generation. Moving it
    # to a threadpool presumably needs a lock around the shared `llm` first
    # (it is also used by /stream_solution) — confirm before changing.
    raw_output = generate_solution(request.instruction)
    solution = extract_and_format_code(raw_output)
    logger.info("Solution generated successfully")
    return {"solution": solution}
|
|
|
# Emitted between the raw token stream and the formatted solution so clients
# can split the two. It is a Python comment line, so the whole streamed body
# stays readable as code.
FORMATTED_SOLUTION_SEPARATOR = "\n\n# --- formatted solution ---\n"


@app.post("/stream_solution")
async def stream_solution_api(request: ProblemRequest, authorized: bool = Depends(verify_token)):
    """Stream the model's raw tokens, then a separator line, then the formatted solution."""

    async def generate() -> AsyncGenerator[str, None]:
        logger.info("Streaming solution")
        # Same prompt as generate_solution(); keep the two in sync.
        system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
        full_prompt = f"""### Instruction:
{system_prompt}

Implement the following function for the LeetCode problem:

{request.instruction}

### Response:
Here's the complete Python function implementation:

```python
"""

        tokens = []
        # NOTE(review): llm(...) is synchronous, so each token is computed on the
        # event-loop thread and blocks other requests while it runs.
        for chunk in llm(full_prompt, stream=True, **generation_kwargs):
            token = chunk["choices"][0]["text"]
            tokens.append(token)  # join once at the end instead of quadratic +=
            yield token

        formatted_code = extract_and_format_code("".join(tokens))
        logger.info("Solution generated successfully")
        # Without a separator the formatted copy would be appended directly to
        # the raw text. Clients could not tell where one ends, and when the raw
        # text lacks a trailing newline the two would share a line.
        yield FORMATTED_SOLUTION_SEPARATOR
        yield formatted_code

    return StreamingResponse(generate(), media_type="text/plain")
|
|
|
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when the file runs as a script.
    import uvicorn

    # Listens on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
|
|