# MathsPro / app.py
import json
import os
import random
import re
import subprocess
import tempfile
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import datasets
import gradio as gr
from datasets import Features, Value
from huggingface_hub import InferenceClient
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from transformers import AutoModelForPreTraining, AutoTokenizer
# Alternative backend: an OpenAI-compatible endpoint.
# from openai import OpenAI
# client = OpenAI(
#     base_url=os.environ.get("SERVER_URL"),
#     api_key=os.environ.get("HF_TOKEN"),
# )
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
@dataclass
class Config:
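    """Generation and evaluation settings for a single run."""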
debug: bool = False
push_to_hub: bool = False
model_id: str = None
revision: str = None
system_prompt: str = None
validation_set: str = None
is_quantized: bool = False
restart_on_fail: bool = False
is_submission: bool = False
num_samples: int = 1
num_generations: int = 1
do_sample: bool = True
temperature: float = 1.0
top_p: float = 0.9
top_k: int = 50
max_new_tokens: int = 100
# Load a pre-trained math BERT model and tokenizer (loaded here but not used
# by the solving pipeline below).
tokenizer = AutoTokenizer.from_pretrained("AnReu/math_pretrained_bert")
model = AutoModelForPreTraining.from_pretrained("AnReu/math_pretrained_bert")
class PythonREPL:
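    """Execute Python code snippets in a subprocess, with a timeout."""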
def __init__(self, timeout=5):
self.timeout = timeout
def execute(self, query: str) -> Tuple[bool, str]:
query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
query = query.strip().split("\n")
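        # If the final line is a bare expression, wrap it in print() so its
        # value appears on stdout (any trailing comment is stripped first).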
if "print(" not in query[-1]:
if "#" in query[-1]:
query[-1] = query[-1].split("#")[0]
query[-1] = "print(" + query[-1] + ")"
query = "\n".join(query)
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w") as f:
                f.write(query)
            try:
                result = subprocess.run(
                    ["python3", temp_file_path],
                    capture_output=True,
                    check=False,
                    text=True,
                    timeout=self.timeout,
                )
            except subprocess.TimeoutExpired:
                return False, f"Timed out after {self.timeout} seconds."
if result.returncode == 0:
output = result.stdout
return True, output.strip()
else:
error_msg = result.stderr.strip()
msgs = error_msg.split("\n")
new_msgs = []
want_next = False
for m in msgs:
if "Traceback" in m:
new_msgs.append(m)
elif m == msgs[-1]:
new_msgs.append(m)
                elif temp_file_path in m:
                    # Redact the temporary file path so error messages don't
                    # leak the sandbox location.
                    m = m.replace(temp_file_path, "<tmp.py>")
                    new_msgs.append(m)
                    want_next = True
elif want_next:
new_msgs.append(m)
want_next = False
error_msg = "\n".join(new_msgs)
return False, error_msg.strip()
def __call__(self, query: str) -> Tuple[bool, str]:
with ThreadPoolExecutor() as executor:
future = executor.submit(self.execute, query)
try:
return future.result(timeout=self.timeout)
except TimeoutError:
return False, f"Timed out after {self.timeout} seconds."
def execute_completion(
executor: PythonREPL,
completion: str,
return_status: bool = False,
last_code_block: bool = False,
) -> str | Tuple[str, bool]:
# executions = ["!" + code for code in re.findall(r"```bash(.*?)```", completion, re.DOTALL) if "!" not in code]
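    # Extract every fenced ```python ... ``` code block from the completion.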
executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if len(executions) == 0:  # no code blocks: return the chain-of-thought result directly
        return (completion, False) if return_status else completion
else:
if last_code_block:
executions = [executions[-1]]
# Python
execution_outputs = []
successes = []
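        # Run each code block in the sandbox; blocks invoking subprocess or
        # venv are refused outright.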
for code in executions:
success = False
if "subprocess" in code:
output = "subprocess is not allowed"
execution_outputs.append(output)
successes.append(success)
continue
if "venv" in code:
output = "venv is not allowed"
execution_outputs.append(output)
successes.append(success)
continue
            try:
                success, output = executor(code)
            except TimeoutError as e:
                print("time out")
                output = str(e)
if not success and not return_status:
output = ""
execution_outputs.append(output)
successes.append(success)
output = str(execution_outputs[-1]).strip()
success = successes[-1]
if return_status:
return output, success
else:
return output
def postprocess_completion(
text: str, return_status: bool = False, last_code_block=False, timeout=5
) -> str | Tuple[str, bool]:
executor = PythonREPL(timeout=timeout)
result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block)
del executor
return result
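# Example: postprocess_completion("```python\nprint(1 + 1)\n```") -> "2"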
def apply_template(example: Dict[str, Any], prompt: str) -> str:
    """Format the prompt template with the example's problem statement."""
    return prompt.format(example["prompt"], "{}")
def last_boxed_only_string(string):
"""
Extracts the last LaTeX boxed or framed expression from a string.
Args:
string (str): The input string containing LaTeX expressions.
Returns:
str or None: The last boxed or framed expression, if found;
otherwise, None.
"""
idx = string.rfind("\\boxed")
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx : right_brace_idx + 1]
return retval
def remove_boxed(s):
"""
Removes the LaTeX boxed command, returning the content inside the braces.
Args:
s (str): The string containing a LaTeX boxed expression.
Returns:
str or None: The content inside the boxed command, if valid;
otherwise, None.
"""
left = "\\boxed{"
try:
assert s[: len(left)] == left
assert s[-1] == "}"
length = len(left)
return s[length:-1]
except Exception:
return None
def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
"""
Extracts the answer from a LaTeX boxed expression within
a prediction string.
Args:
pred_str (str): The string containing one or more LaTeX
boxed expressions.
strip_double_curly_brace (bool): If True, removes an additional
layer of braces.
Returns:
str or None: The extracted answer, if any; otherwise, None.
"""
boxed_str = last_boxed_only_string(pred_str)
if boxed_str is None:
return None
answer = remove_boxed(boxed_str)
if answer is None:
return None
if strip_double_curly_brace:
        match = re.match(r"^\{(.*)\}$", answer)
if match:
answer = match.group(1)
return answer
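# Example: extract_boxed_answer(r"the answer is \boxed{42}") -> "42"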
def normalize_final_answer(final_answer: str) -> str:
"""
Normalizes a final answer string by removing or replacing various LaTeX
and text elements.
Args:
final_answer (str): The answer string to normalize.
Returns:
str: The normalized answer string.
"""
    match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
    if match:
        final_answer = match.group(1)  # keep only the text that precedes "Problem:"
    # final_answer = final_answer.split('=')[-1]
SUBSTITUTIONS = [
("an ", ""),
("a ", ""),
(".$", "$"),
("\\$", ""),
(r"\ ", ""),
(" ", ""),
("mbox", "text"),
(",\\text{and}", ","),
("\\text{and}", ","),
("\\text{m}", "\\text{}"),
("\\le", "<"),
]
REMOVED_EXPRESSIONS = [
"square",
"ways",
"integers",
"dollars",
"mph",
"inches",
"ft",
"hours",
"km",
"units",
"\\ldots",
"sue",
"points",
"feet",
"minutes",
"digits",
"cents",
"degrees",
"cm",
"gm",
"pounds",
"meters",
"meals",
"edges",
"students",
"childrentickets",
"multiples",
"\\text{s}",
"\\text{.}",
"\\text{\ns}",
"\\text{}^2",
"\\text{}^3",
"\\text{\n}",
"\\text{}",
r"\mathrm{th}",
r"^\circ",
r"^{\circ}",
r"\;",
r",\!",
"{,}",
'"',
"\\dots",
"\n",
"\r",
"\f",
"\%",
]
for before, after in SUBSTITUTIONS:
final_answer = final_answer.replace(before, after)
for expr in REMOVED_EXPRESSIONS:
final_answer = final_answer.replace(expr, "")
# Extract answer that is in LaTeX math, is bold,
# is surrounded by a box, etc.
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
assert "\n" not in final_answer
assert "\r" not in final_answer
assert "\f" not in final_answer
if len(re.findall(r"finalansweris(.*)", final_answer)) > 0:
final_answer = re.findall(r"finalansweris(.*)", final_answer)[-1]
if len(re.findall(r"answer?is:?(.*)", final_answer)) > 0:
final_answer = re.findall(r"answer?is:?(.*)", final_answer)[-1]
if len(re.findall(r"oxed\{(.*?)\}", final_answer)) > 0:
final_answer = re.findall(r"oxed\{(.*?)\}", final_answer)[-1]
if len(re.findall(r"\$(.*?)\$", final_answer)) > 0:
final_answer = re.findall(r"\$(.*?)\$", final_answer)[-1]
final_answer = final_answer.strip()
if "rac" in final_answer and "\\frac" not in final_answer:
final_answer = final_answer.replace("rac", "\\frac")
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
final_answer = final_answer.replace("$", "")
if final_answer.replace(",", "").isdigit():
final_answer = final_answer.replace(",", "")
return final_answer
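# Example: normalize_final_answer("The final answer is $\\frac{1}{2}$") -> "\\frac{1}{2}"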
def naive_parse(answer: str) -> str:
"""
Extracts and returns the numeric digits from the input string, processing them in reverse order
until a non-numeric character is encountered after encountering the first numeric character.
Args:
answer (str): The input string to parse.
Returns:
str: A string consisting of the numeric digits extracted from the input, in their original order.
Example:
>>> naive_parse("abc123def")
'123'
>>> naive_parse("def456ghi")
'456'
>>> naive_parse("no numbers here")
''
"""
out = []
start = False
end = False
for l in reversed(list(answer)):
if l in "0123456789" and not end:
start = True
out.append(l)
else:
if start:
end = True
out = reversed(out)
return "".join(out)
def validate_answer_is_numeric(x: str | int | float) -> int:
    FLOAT_TOLERANCE = 0.2
    try:
        f = float(x)
        x = round(f)
        if abs(x - f) > FLOAT_TOLERANCE:
            x = -1
    except Exception:
        x = -1
    return x
def filter_answers(answers: List[str]) -> List[int]:
formatted_answers = [validate_answer_is_numeric(a) for a in answers]
# Filter for non-negative answers
formatted_answers = [a for a in formatted_answers if a >= 0]
    # Reduce modulo 1,000
    formatted_answers = [a % 1_000 for a in formatted_answers]
    # Keep answers in [0, 999] (always true after the modulo; kept as a guard)
    formatted_answers = [a for a in formatted_answers if a <= 999]
return formatted_answers
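# Example: filter_answers(["1234", "-5", "foo"]) -> [234]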
def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
def do_answers_match(ref_answer: str, model_answer: str) -> bool:
ref_sympy = parse_latex(ref_answer)
model_sympy = parse_latex(model_answer)
diff = simplify(ref_sympy - model_sympy)
        return bool(diff.is_zero or -1e-12 < N(diff) < 1e-12)
try:
result = do_answers_match(ref_answer, model_answer)
return result
except Exception as e:
print(e)
return False
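# Example: check_sympy_equivalence(r"\frac{1}{2}", "0.5") -> True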
def check_string_match(ref_answer: str, model_answer: str) -> bool:
try:
return ref_answer == model_answer
except Exception as e:
print(e)
return False
def check_answer(ref_answer: str, model_answer: str) -> bool:
# check if strings are the same
correct = check_string_match(ref_answer, model_answer)
if correct:
return True
# use the sympy library to check if the expressions are the same
correct = check_sympy_equivalence(ref_answer, model_answer)
if correct:
return True
return False
debug = False
model_id = "athstral-7B-v0.m1"
revision = "main"
system_prompt = "{}"
validation_set = "kaggle-validation-set-medium"
is_submission = True
num_samples = 4
num_generations = 4
temperature = 0.8
is_quantized = False
restart_on_fail = False
top_p = 1.0
top_k = 0
max_new_tokens = 2048
# Papermill related variables
push_to_hub = False
notebook_name = ""
config = Config(
debug=False,
push_to_hub=False,
model_id=model_id,
revision=revision,
system_prompt=system_prompt,
validation_set=validation_set,
is_quantized=is_quantized,
restart_on_fail=restart_on_fail,
is_submission=is_submission,
num_samples=num_samples,
num_generations=num_generations,
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_new_tokens=max_new_tokens
)
print(f"=== Running submission with config ===\n\n{config}")
def parse_data_chunk(data_chunk):
    """
    Parse a raw server-sent-events payload into a list of individual entries.

    The function splits the input string on the "data:" delimiter and strips
    leading and trailing whitespace from each resulting chunk. Empty chunks
    are filtered out of the final list.

    Parameters:
        data_chunk (str): The input string containing entries separated by "data:".

    Returns:
        list: A list of individual data entries with whitespace stripped.
    """
    chunks = data_chunk.split("data:")
    return [chunk.strip() for chunk in chunks if chunk.strip()]
def generate(messages, temperature):
    """
    Stream a chat completion for `messages` from the inference client.

    Yields:
        tuple: (content, error), where content is the next generated text
        fragment and error is True if the server reported an error.
    """
    response = client.chat_completion(
        messages=messages,
        temperature=temperature,
        max_tokens=config.max_new_tokens,
        stream=True,
    )
    for chunk in response:
        if isinstance(chunk, (bytes, str)):
            # Raw SSE payloads must be decoded and split on "data:" before
            # their JSON bodies can be parsed.
            chunk = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
            data_chunks = parse_data_chunk(chunk)
            try:
                for data_chunk in data_chunks:
                    chunk_json = json.loads(data_chunk)
                    if "error" in chunk_json and chunk_json["error"]:
                        yield chunk_json["error"], True
                        break
                    delta = chunk_json["choices"][0]["delta"]
                    content = delta.get("content", "")
                    if content != "":
                        yield content, False
            except (json.JSONDecodeError, KeyError) as e:
                print(f"func: generate error occurred\nchunk:{chunk}\nerror:{e}")
                raise
        else:
            # Already-parsed ChatCompletionStreamOutput objects from
            # InferenceClient.chat_completion(stream=True).
            content = chunk.choices[0].delta.content or ""
            if content != "":
                yield content, False
def get_majority_text(data):
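    """Return the generated text corresponding to the most common model answer."""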
from collections import Counter
# Count the frequency of each answer in model_answers
answer_counts = Counter(data["model_answers"])
# Find the majority response
majority_response = answer_counts.most_common(1)[0][0]
# Find the index of the first occurrence of the majority response
majority_index = data["model_answers"].index(majority_response)
# Return the corresponding text in gen_texts
return data["gen_texts"][majority_index]
def extract_solution(text):
# Split the text at "### Solution:"
parts = text.split("### Solution:", 1)
if len(parts) > 1:
# Return everything after "### Solution:"
return parts[1].strip()
else:
# Return an empty string if "### Solution:" is not found
return ""
def process_code(
example: Dict[str, Any],
config: Config,
restart_on_fail: bool = False,
last_step: bool = False,
) -> Dict[str, Any]:
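    """Execute the last generated code block and append its output to the
    generation, deciding whether this branch should stop ("prune") or continue."""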
gen_text = example["gen_texts"]
num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL))
if num_python_blocks == 0:
if restart_on_fail:
print("no code has ever been generated, RESTARTING")
# reset the text to the original
example["gen_texts"] = example["text"]
else:
print("no code has ever been generated, STOP")
example["should_prune"] = True
example["has_code"] = False
return example
if gen_text[-10:] != "```output\n" and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]):
num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
if num_output_blocks == 0:
print("the model hallucinated the code answer")
example["should_prune"] = True
return example
if "boxed" in gen_text[-100:]:
try:
answer = normalize_final_answer(extract_boxed_answer(gen_text[-100:]))
except Exception:
answer = "-1"
else:
answer = normalize_final_answer(gen_text[-100:])
example["model_answers"] = answer
if not config.is_submission:
example["corrects"] = check_answer(example["ground_truth"], answer)
example["should_prune"] = True
print("Answer is: ", answer, example["ground_truth"], example["corrects"])
return example
if last_step:
# no point in continuing if we are at the last step
return example
if gen_text[-10:] != "```output\n":
# something else has gone wrong with the generation
print("warning: output block not found: ", gen_text[-40:])
if restart_on_fail:
example["gen_texts"] = example["text"]
else:
example["should_prune"] = True
return example
code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
# add the code result for the next round of generation
TRUNCATION_LIMIT = 200
if len(code_result) > TRUNCATION_LIMIT:
code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
example["gen_texts"] = gen_text + f"{code_result}\n```"
return example
def solve_problem(problem, temperature, progress=gr.Progress()):
    """
    Solve a problem with interleaved generation and code execution.

    Yields:
        tuple: (text, stop), where text is the full generated text so far and
        stop is True once generation has finished.
    """
    problem = apply_template({"prompt": problem}, prompt=config.system_prompt)
    print(f"Problem: {problem}")
    sample = {
        "problem": problem,  # not used for the submission TODO Remove
        "ground_truth": "unknown",  # not used for the submission TODO Remove
        "text": "## Solution:\n",
        "gen_texts": "",  # used to store all the generated text
        "should_prune": False,
        "problem_index": -1,  # not used for the submission TODO Remove
        "model_answers": "-1",
        "has_code": True,
        "corrects": False,  # not used for the submission TODO Remove
    }
    for step in progress.tqdm(
        range(config.num_generations), desc="Generating candidates"
    ):  # depth of the tree (e.g. 6 steps = 5 code blocks)
        step_response = sample["gen_texts"]
        messages = [
            {"role": "user", "content": sample["problem"]},
            {"role": "assistant", "content": sample["gen_texts"]},
        ]
        stop = False
        for response_message, error in generate(messages, temperature):
            if response_message is not None:
                step_response += response_message
                yield step_response, False
            if error:
                stop = True
        sample["gen_texts"] = step_response
        # TODO: Maybe it should just return the result of running the code
        sample = process_code(
            sample,
            config=config,
            restart_on_fail=config.restart_on_fail,
            last_step=(step == (config.num_generations - 1)),
        )
        sample["gen_texts"] = sample["gen_texts"] + "\n"
        # Stream the newly appended code output character by character.
        run_code_response = sample["gen_texts"].replace(step_response, "")
        for output_message in run_code_response:
            if output_message is not None:
                step_response += output_message
                yield step_response, False
        if sample["should_prune"] or stop:
            break
    yield sample["gen_texts"], True
features = Features({
'id': Value('int64'),
'problem': Value('string'),
'answer': Value('string'),
#'prompt': Value('string'), # Ensure this matches the actual data type of 'prompt' in your dataset
#'level': Value('string')
})
# Now load the dataset using the defined schema
example_data = datasets.load_dataset(
    "AI-MO/aimo-validation-math-level-5",
    split="train",
    token=os.environ.get("HF_DATASET_TOKEN", None),
    features=features,  # pass the schema definition here
)
with open("app.css", "r") as f:
css = f.read()
latex_delimiters = [
{"left": "[", "right": "]", "display": True},
]
def get_random_problem():
example = random.choice(list(example_data))
problem = example["problem"]
return problem
def update_example_problem():
problem_example_text = get_random_problem()
return problem_example_text, problem_example_text
def clear():
problem_example_text = get_random_problem()
return "", 0.1, "", problem_example_text, problem_example_text
def preprocess_output(text):
return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)")
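# Build the Gradio UI: problem input and example on the left, streamed solution on the right.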
with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
btn_list = []
problem_input_ele_list = []
problem_example_text = get_random_problem()
with gr.Row(elem_classes="title"):
gr.HTML("Math Olympiad Solver", elem_classes="title-content")
with gr.Row(elem_classes="sub-title"):
        gr.HTML(
            "<div>Demo of maths problem solving with AI models. Example problems are drawn at random from a validation set.</div>",
            elem_classes="sub-title-content",
        )
with gr.Row(elem_classes="main-area"):
with gr.Column(scale=1, elem_classes="left"):
with gr.Row(elem_classes="probelm-example-container"):
with gr.Blocks(elem_classes="probelm-example-title"):
gr.HTML("Problem example", elem_classes="probelm-example-title-content")
with gr.Blocks(elem_classes="action-container"):
another_btn = gr.Button(
"",
elem_classes="probelm-example-another",
icon="./static/images/reset.png",
)
copy_btn = gr.Button("Copy", elem_classes="probelm-example-copy")
problem_example = gr.HTML(
problem_example_text,
elem_classes="probelm-example-content",
)
with gr.Row(elem_classes="probelm-input-container"):
inp = gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True)
problem_markdown = gr.Markdown(
visible=False,
latex_delimiters=[
{"left": "[", "right": "]", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": r"\(", "right": r"\)", "display": False},
],
)
inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown])
problem_input_ele_list.append(inp)
problem_input_ele_list.append(problem_markdown)
with gr.Accordion("Advanced Options", open=False):
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature")
with gr.Row() as btn_area:
btn_clear = gr.Button("Clear", elem_classes="clear-btn")
btn_run = gr.Button("Run", elem_classes="run-btn")
btn_list.append(btn_clear)
btn_list.append(btn_run)
with gr.Column(scale=1, elem_classes="right"):
gr.HTML("Solution", elem_classes="solution-title-content")
out = gr.Markdown(
elem_classes="solution-content",
latex_delimiters=[
{"left": "[", "right": "]", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": r"\(", "right": r"\)", "display": False},
],
)
problem_example_text_hidden = gr.Markdown(value=problem_example_text, visible=False)
    def solve_problem_wrapper(inp_text, temperature):
        # Swap the Run button for a "running" state while tokens stream in.
        new_running_btn = gr.Button("", elem_classes="run-btn running-btn")
        for after_tokens, stop in solve_problem(inp_text, temperature):
            yield preprocess_output(after_tokens), new_running_btn
            if stop:
                yield preprocess_output(after_tokens), gr.Button("Run", elem_classes="run-btn")
def mount_run_btn(btn):
btn.click(fn=solve_problem_wrapper, inputs=[inp, temperature], outputs=[out, btn_list[1]])
btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list)
def get_run_after_problem_input():
return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=False), gr.Markdown(
visible=True,
latex_delimiters=[
{"left": "[", "right": "]", "display": True},
{"left": "$", "right": "$", "display": False},
],
elem_classes="problem-input-markdown",
)
def get_init_problem_input():
return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True), gr.Markdown(
visible=False,
latex_delimiters=[
{"left": "[", "right": "]", "display": True},
{"left": "$", "right": "$", "display": False},
],
)
copy_btn.click(fn=lambda example: example, inputs=[problem_example_text_hidden], outputs=[inp])
btn_clear.click(
fn=clear,
inputs=[],
outputs=[
inp,
temperature,
out,
problem_example,
problem_example_text_hidden,
],
)
btn_clear.click(get_init_problem_input, None, outputs=problem_input_ele_list)
mount_run_btn(btn_run)
demo.load(
update_example_problem,
inputs=None,
outputs=[
problem_example,
problem_example_text_hidden,
],
)
another_btn.click(
fn=update_example_problem,
inputs=[],
outputs=[
problem_example,
problem_example_text_hidden,
],
)
if __name__ == "__main__":
demo.queue(default_concurrency_limit=5).launch(share=True)