sugiv committed on
Commit
8081ccc
1 Parent(s): f18de00

Leetmonkey In Action via Inference

Browse files
Files changed (2) hide show
  1. app.py +112 -68
  2. requirements.txt +3 -5
app.py CHANGED
@@ -1,43 +1,30 @@
1
- import os
2
- import re
3
- import logging
4
- import textwrap
5
- import autopep8
6
  import gradio as gr
7
  from huggingface_hub import hf_hub_download
8
  from llama_cpp import Llama
9
- import jwt
10
- from typing import Dict, Any
 
 
 
 
 
11
 
12
  # Set up logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
- # JWT settings
17
- JWT_SECRET = os.environ.get("JWT_SECRET")
18
- if not JWT_SECRET:
19
- raise ValueError("JWT_SECRET environment variable is not set")
20
- JWT_ALGORITHM = "HS256"
21
-
22
- # Model settings
23
- MODEL_NAME = "leetmonkey_peft__q8_0.gguf"
24
- REPO_ID = "sugiv/leetmonkey-peft-gguf"
25
-
26
- # Generation parameters
27
- generation_kwargs = {
28
- "max_tokens": 2048,
29
- "stop": ["```", "### Instruction:", "### Response:"],
30
- "echo": False,
31
- "temperature": 0.2,
32
- "top_k": 50,
33
- "top_p": 0.95,
34
- "repeat_penalty": 1.1
35
  }
36
 
37
- def download_model(model_name: str) -> str:
38
  logger.info(f"Downloading model: {model_name}")
39
  model_path = hf_hub_download(
40
- repo_id=REPO_ID,
41
  filename=model_name,
42
  cache_dir="./models",
43
  force_download=True,
@@ -47,17 +34,32 @@ def download_model(model_name: str) -> str:
47
  return model_path
48
 
49
  # Download and load the 8-bit model at startup
50
- model_path = download_model(MODEL_NAME)
51
  llm = Llama(
52
- model_path=model_path,
53
  n_ctx=2048,
54
  n_threads=4,
55
- n_gpu_layers=-1, # Use all available GPU layers
56
  verbose=False
57
  )
58
  logger.info("8-bit model loaded successfully")
59
 
60
- def generate_solution(instruction: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
62
  full_prompt = f"""### Instruction:
63
  {system_prompt}
@@ -72,26 +74,36 @@ Here's the complete Python function implementation:
72
  ```python
73
  """
74
 
75
- response = llm(full_prompt, **generation_kwargs)
76
  return response["choices"][0]["text"]
77
 
78
- def extract_and_format_code(text: str) -> str:
 
79
  code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
80
  if code_match:
81
  code = code_match.group(1)
82
  else:
83
  code = text
84
 
 
85
  code = re.sub(r'^.*?(?=def\s+\w+\s*\()', '', code, flags=re.DOTALL)
 
 
86
  code = textwrap.dedent(code)
 
 
87
  lines = code.split('\n')
 
 
88
  func_def_index = next((i for i, line in enumerate(lines) if line.strip().startswith('def ')), 0)
89
- indented_lines = [lines[func_def_index]]
 
 
90
  for line in lines[func_def_index + 1:]:
91
- if line.strip():
92
- indented_lines.append(' ' + line)
93
  else:
94
- indented_lines.append(line)
95
 
96
  formatted_code = '\n'.join(indented_lines)
97
 
@@ -100,38 +112,70 @@ def extract_and_format_code(text: str) -> str:
100
  except:
101
  return formatted_code
102
 
103
- def verify_token(token: str) -> bool:
104
- try:
105
- jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
106
- return True
107
- except jwt.PyJWTError:
108
- return False
109
-
110
- def generate_code(instruction: str, token: str) -> Dict[str, Any]:
111
- if not verify_token(token):
112
- return {"error": "Invalid token"}
113
 
114
- logger.info("Generating solution")
115
- generated_output = generate_solution(instruction)
116
  formatted_code = extract_and_format_code(generated_output)
117
  logger.info("Solution generated successfully")
118
- return {"solution": formatted_code}
119
-
120
- # Gradio API
121
- api = gr.Interface(
122
- fn=generate_code,
123
- inputs=[
124
- gr.Textbox(label="LeetCode Problem Instruction"),
125
- gr.Textbox(label="JWT Token")
126
- ],
127
- outputs=gr.JSON(),
128
- title="LeetCode Problem Solver API",
129
- description="Provide a LeetCode problem instruction and a valid JWT token to generate a solution.",
130
- examples=[
131
- ["Implement a function to reverse a linked list", "your_jwt_token_here"],
132
- ["Write a function to find the maximum subarray sum", "your_jwt_token_here"]
133
- ]
134
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  if __name__ == "__main__":
137
- api.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from huggingface_hub import hf_hub_download
3
  from llama_cpp import Llama
4
+ import re
5
+ from datasets import load_dataset
6
+ import random
7
+ import logging
8
+ import os
9
+ import autopep8
10
+ import textwrap
11
 
12
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define the model options
# Maps the human-readable quantization label (shown in the Gradio dropdown)
# to the corresponding GGUF filename in the Hugging Face Hub repository.
gguf_models = {
    "Q8_0 (8-bit)": "leetmonkey_peft__q8_0.gguf",
    "Exact Copy": "leetmonkey_peft_exact_copy.gguf",
    "F16": "leetmonkey_peft_f16.gguf",
    "Super Block Q6": "leetmonkey_peft_super_block_q6.gguf"
}
23
 
24
+ def download_model(model_name):
25
  logger.info(f"Downloading model: {model_name}")
26
  model_path = hf_hub_download(
27
+ repo_id="sugiv/leetmonkey-peft-gguf",
28
  filename=model_name,
29
  cache_dir="./models",
30
  force_download=True,
 
34
  return model_path
35
 
36
  # Download and load the 8-bit model at startup
37
+ q8_model_path = download_model(gguf_models["Q8_0 (8-bit)"])
38
  llm = Llama(
39
+ model_path=q8_model_path,
40
  n_ctx=2048,
41
  n_threads=4,
42
+ n_gpu_layers=0,
43
  verbose=False
44
  )
45
  logger.info("8-bit model loaded successfully")
46
 
47
# Load the dataset
# Only the "train" split is used, as the pool of problems to sample from.
dataset = load_dataset("sugiv/leetmonkey_python_dataset")
train_dataset = dataset["train"]

# Generation parameters
# Shared sampling settings passed to every llama.cpp generation call.
generation_kwargs = {
    "max_tokens": 2048,                                    # hard cap on generated tokens
    "stop": ["```", "### Instruction:", "### Response:"],  # stop at end of code block or a new prompt section
    "echo": False,                                         # don't repeat the prompt in the output
    "temperature": 0.2,
    "top_k": 50,
    "top_p": 0.95,
    "repeat_penalty": 1.1
}
61
+
62
+ def generate_solution(instruction, model):
63
  system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
64
  full_prompt = f"""### Instruction:
65
  {system_prompt}
 
74
  ```python
75
  """
76
 
77
+ response = model(full_prompt, **generation_kwargs)
78
  return response["choices"][0]["text"]
79
 
80
+ def extract_and_format_code(text):
81
+ # Extract code between triple backticks
82
  code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
83
  if code_match:
84
  code = code_match.group(1)
85
  else:
86
  code = text
87
 
88
+ # Remove any text before the function definition
89
  code = re.sub(r'^.*?(?=def\s+\w+\s*\()', '', code, flags=re.DOTALL)
90
+
91
+ # Dedent the code to remove any common leading whitespace
92
  code = textwrap.dedent(code)
93
+
94
+ # Split the code into lines
95
  lines = code.split('\n')
96
+
97
+ # Find the function definition line
98
  func_def_index = next((i for i, line in enumerate(lines) if line.strip().startswith('def ')), 0)
99
+
100
+ # Ensure proper indentation
101
+ indented_lines = [lines[func_def_index]] # Keep the function definition as is
102
  for line in lines[func_def_index + 1:]:
103
+ if line.strip(): # If the line is not empty
104
+ indented_lines.append(' ' + line) # Add 4 spaces of indentation
105
  else:
106
+ indented_lines.append(line) # Keep empty lines as is
107
 
108
  formatted_code = '\n'.join(indented_lines)
109
 
 
112
  except:
113
  return formatted_code
114
 
115
def select_random_problem():
    """Pick one problem at random from the training split and return its instruction text."""
    idx = random.randrange(len(train_dataset))
    return train_dataset[idx]['instruction']
117
+
118
def update_solution(problem, model_name):
    """Generate a formatted solution for *problem* with the selected GGUF model.

    Args:
        problem: The LeetCode problem instruction text.
        model_name: Key into the module-level ``gguf_models`` mapping.

    Returns:
        The extracted and auto-formatted Python solution as a string.
    """
    # Cache non-default models on the function object so repeated calls do not
    # re-download (download_model uses force_download=True) and re-load the
    # same GGUF file for every single request.
    if not hasattr(update_solution, "_model_cache"):
        update_solution._model_cache = {}

    if model_name == "Q8_0 (8-bit)":
        model = llm  # default model, already loaded at startup
    elif model_name in update_solution._model_cache:
        model = update_solution._model_cache[model_name]
    else:
        model_path = download_model(gguf_models[model_name])
        model = Llama(model_path=model_path, n_ctx=2048, n_threads=4, n_gpu_layers=0, verbose=False)
        update_solution._model_cache[model_name] = model

    logger.info(f"Generating solution using {model_name} model")
    generated_output = generate_solution(problem, model)
    formatted_code = extract_and_format_code(generated_output)
    logger.info("Solution generated successfully")
    return formatted_code
130
+
131
def stream_solution(problem, model_name):
    """Stream a solution token by token, then yield the cleaned final code.

    Yields the raw generated text as it grows (so the UI can display partial
    output), and finally yields the extracted/auto-formatted implementation.

    Args:
        problem: The LeetCode problem instruction text.
        model_name: Key into the module-level ``gguf_models`` mapping.
    """
    # Cache non-default models on the function object so every streamed
    # request does not re-download and re-load the same GGUF file.
    if not hasattr(stream_solution, "_model_cache"):
        stream_solution._model_cache = {}

    if model_name == "Q8_0 (8-bit)":
        model = llm  # default model, already loaded at startup
    elif model_name in stream_solution._model_cache:
        model = stream_solution._model_cache[model_name]
    else:
        model_path = download_model(gguf_models[model_name])
        model = Llama(model_path=model_path, n_ctx=2048, n_threads=4, n_gpu_layers=0, verbose=False)
        stream_solution._model_cache[model_name] = model

    logger.info(f"Generating solution using {model_name} model")
    system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
    full_prompt = f"""### Instruction:
{system_prompt}

Implement the following function for the LeetCode problem:

{problem}

### Response:
Here's the complete Python function implementation:

```python
"""

    # Accumulate tokens as they arrive and yield the growing text each time.
    generated_text = ""
    for chunk in model(full_prompt, stream=True, **generation_kwargs):
        token = chunk["choices"][0]["text"]
        generated_text += token
        yield generated_text

    # Final yield: the cleaned-up, properly indented function body.
    formatted_code = extract_and_format_code(generated_text)
    logger.info("Solution generated successfully")
    yield formatted_code
162
+
163
# Gradio UI: problem picker on the left, model selector and generated
# solution on the right.
with gr.Blocks() as demo:
    gr.Markdown("# LeetCode Problem Solver")

    with gr.Row():
        with gr.Column():
            problem_display = gr.Textbox(label="LeetCode Problem", lines=10)
            select_problem_btn = gr.Button("Select Random Problem")

        with gr.Column():
            model_dropdown = gr.Dropdown(choices=list(gguf_models.keys()), label="Select GGUF Model", value="Q8_0 (8-bit)")
            solution_display = gr.Code(label="Generated Solution", language="python", lines=25)
            generate_btn = gr.Button("Generate Solution")

    # Wire the buttons: picking a problem fills the textbox; generating
    # streams output into the code display via the stream_solution generator.
    select_problem_btn.click(select_random_problem, outputs=problem_display)
    generate_btn.click(stream_solution, inputs=[problem_display, model_dropdown], outputs=solution_display)
178
 
179
# Script entry point: launch the Gradio app (share=True requests a public
# shareable link from Gradio).
if __name__ == "__main__":
    logger.info("Starting Gradio interface")
    demo.launch(share=True)
requirements.txt CHANGED
@@ -1,8 +1,6 @@
1
  gradio
2
  llama-cpp-python
3
- huggingface_hub
4
- pyjwt
5
  autopep8
6
- fastapi
7
- uvicorn
8
- pydantic
 
1
  gradio
2
  llama-cpp-python
3
+ datasets
4
+ transformers
5
  autopep8
6
+ huggingface_hub