|
import gradio as gr |
|
|
|
from dataclasses import dataclass |
|
from concurrent.futures import ThreadPoolExecutor, TimeoutError |
|
from huggingface_hub import InferenceClient |
|
import os |
|
import re |
|
import subprocess |
|
import tempfile |
|
import json |
|
from datasets import load_dataset, Features, Value
|
import random |
|
import time |
|
from typing import Tuple, Dict, Any, List |
|
from sympy import N, simplify |
|
from sympy.parsing.latex import parse_latex |
|
|
|
from transformers import AutoTokenizer, AutoModelForPreTraining
|
|
|
|
|
|
|
|
|
|
|
|
|
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta") |
|
|
|
|
|
|
|
@dataclass |
|
class Config: |
|
debug: bool = False |
|
push_to_hub: bool = False |
|
    model_id: str | None = None
    revision: str | None = None
    system_prompt: str | None = None
    validation_set: str | None = None
|
is_quantized: bool = False |
|
restart_on_fail: bool = False |
|
is_submission: bool = False |
|
num_samples: int = 1 |
|
num_generations: int = 1 |
|
do_sample: bool = True |
|
temperature: float = 1.0 |
|
top_p: float = 0.9 |
|
top_k: int = 50 |
|
max_new_tokens: int = 100 |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("AnReu/math_pretrained_bert") |
|
model = AutoModelForPreTraining.from_pretrained("AnReu/math_pretrained_bert") |
|
|
|
class PythonREPL: |
|
def __init__(self, timeout=5): |
|
self.timeout = timeout |
|
|
|
    def execute(self, query: str) -> Tuple[bool, str]:
        # Prepend common imports so generated snippets can use them directly.
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        # If the last line is a bare expression, wrap it in print() so its
        # value is captured on stdout (dropping any trailing comment first).
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        query = "\n".join(lines)
|
|
|
with tempfile.TemporaryDirectory() as temp_dir: |
|
temp_file_path = os.path.join(temp_dir, "tmp.py") |
|
|
|
with open(temp_file_path, "w") as f: |
|
f.write(query) |
|
|
|
            try:
                result = subprocess.run(
                    ["python3", temp_file_path],
                    capture_output=True,
                    check=False,
                    text=True,
                    timeout=self.timeout,
                )
            except subprocess.TimeoutExpired:
                return False, f"Timed out after {self.timeout} seconds."
|
|
|
if result.returncode == 0: |
|
output = result.stdout |
|
return True, output.strip() |
|
else: |
|
error_msg = result.stderr.strip() |
|
msgs = error_msg.split("\n") |
|
new_msgs = [] |
|
want_next = False |
|
for m in msgs: |
|
if "Traceback" in m: |
|
new_msgs.append(m) |
|
elif m == msgs[-1]: |
|
new_msgs.append(m) |
|
                    elif temp_file_path in m:
                        # Hide the temporary file path so the error message
                        # references only the script name.
                        m = m.replace(temp_file_path, "tmp.py")
                        new_msgs.append(m)
                        want_next = True
|
elif want_next: |
|
new_msgs.append(m) |
|
want_next = False |
|
error_msg = "\n".join(new_msgs) |
|
return False, error_msg.strip() |
|
|
|
def __call__(self, query: str) -> Tuple[bool, str]: |
|
with ThreadPoolExecutor() as executor: |
|
future = executor.submit(self.execute, query) |
|
try: |
|
return future.result(timeout=self.timeout) |
|
except TimeoutError: |
|
return False, f"Timed out after {self.timeout} seconds." |
|
|
|
|
|
def execute_completion( |
|
executor: PythonREPL, |
|
completion: str, |
|
return_status: bool = False, |
|
last_code_block: bool = False, |
|
) -> str | Tuple[str, bool]: |
|
|
|
executions = re.findall(r"```python(.*?)```", completion, re.DOTALL) |
|
|
|
    if len(executions) == 0:
        # Nothing to execute: return early (with status if requested).
        return (completion, False) if return_status else completion
|
else: |
|
if last_code_block: |
|
executions = [executions[-1]] |
|
|
|
|
|
execution_outputs = [] |
|
successes = [] |
|
for code in executions: |
|
success = False |
|
|
|
if "subprocess" in code: |
|
output = "subprocess is not allowed" |
|
execution_outputs.append(output) |
|
successes.append(success) |
|
continue |
|
|
|
if "venv" in code: |
|
output = "venv is not allowed" |
|
execution_outputs.append(output) |
|
successes.append(success) |
|
continue |
|
|
|
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("time out")
            output = str(e)
|
|
|
if not success and not return_status: |
|
output = "" |
|
|
|
execution_outputs.append(output) |
|
successes.append(success) |
|
|
|
output = str(execution_outputs[-1]).strip() |
|
success = successes[-1] |
|
|
|
if return_status: |
|
return output, success |
|
else: |
|
return output |
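
# Example (sketch): only fenced ```python blocks are executed; the output of
# the last executed block is returned.
#
#   text = "Let's compute:\n```python\nprint(2**10)\n```"
#   execute_completion(PythonREPL(), text, return_status=True)
#   # -> ("1024", True)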
|
|
|
|
|
def postprocess_completion( |
|
text: str, return_status: bool = False, last_code_block=False, timeout=5 |
|
) -> str | Tuple[str, bool]: |
|
executor = PythonREPL(timeout=timeout) |
|
|
|
result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block) |
|
del executor |
|
|
|
return result |
|
|
|
|
|
def apply_template(example: Dict[str, Any], prompt: str) -> str:
    return prompt.format(example["prompt"], "{}")
|
|
|
|
|
def last_boxed_only_string(string): |
|
""" |
|
Extracts the last LaTeX boxed or framed expression from a string. |
|
Args: |
|
string (str): The input string containing LaTeX expressions. |
|
Returns: |
|
str or None: The last boxed or framed expression, if found; |
|
otherwise, None. |
|
""" |
|
|
|
idx = string.rfind("\\boxed") |
|
if idx < 0: |
|
idx = string.rfind("\\fbox") |
|
if idx < 0: |
|
return None |
|
|
|
i = idx |
|
right_brace_idx = None |
|
num_left_braces_open = 0 |
|
while i < len(string): |
|
if string[i] == "{": |
|
num_left_braces_open += 1 |
|
if string[i] == "}": |
|
num_left_braces_open -= 1 |
|
if num_left_braces_open == 0: |
|
right_brace_idx = i |
|
break |
|
i += 1 |
|
|
|
if right_brace_idx is None: |
|
retval = None |
|
else: |
|
retval = string[idx : right_brace_idx + 1] |
|
|
|
return retval |
|
|
|
|
|
def remove_boxed(s): |
|
""" |
|
Removes the LaTeX boxed command, returning the content inside the braces. |
|
Args: |
|
s (str): The string containing a LaTeX boxed expression. |
|
Returns: |
|
str or None: The content inside the boxed command, if valid; |
|
otherwise, None. |
|
""" |
|
|
|
left = "\\boxed{" |
|
try: |
|
assert s[: len(left)] == left |
|
assert s[-1] == "}" |
|
length = len(left) |
|
return s[length:-1] |
|
except Exception: |
|
return None |
|
|
|
|
|
def extract_boxed_answer(pred_str, strip_double_curly_brace=False): |
|
""" |
|
Extracts the answer from a LaTeX boxed expression within |
|
a prediction string. |
|
Args: |
|
pred_str (str): The string containing one or more LaTeX |
|
boxed expressions. |
|
strip_double_curly_brace (bool): If True, removes an additional |
|
layer of braces. |
|
Returns: |
|
str or None: The extracted answer, if any; otherwise, None. |
|
""" |
|
|
|
boxed_str = last_boxed_only_string(pred_str) |
|
if boxed_str is None: |
|
return None |
|
answer = remove_boxed(boxed_str) |
|
if answer is None: |
|
return None |
|
if strip_double_curly_brace: |
|
        match = re.match(r"^\{(.*)\}$", answer)
|
if match: |
|
answer = match.group(1) |
|
return answer |
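
# Example: extract_boxed_answer("The answer is \\boxed{42}.") -> "42"
# (last_boxed_only_string locates "\\boxed{42}"; remove_boxed strips the wrapper).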
|
|
|
|
|
def normalize_final_answer(final_answer: str) -> str: |
|
""" |
|
Normalizes a final answer string by removing or replacing various LaTeX |
|
and text elements. |
|
Args: |
|
final_answer (str): The answer string to normalize. |
|
Returns: |
|
str: The normalized answer string. |
|
""" |
|
|
|
match = re.search(r"(.*?)Problem:", final_answer, flags=re.S) |
|
if match: |
|
final_answer = match.group(1) |
|
"""Normalize a final answer to a quantitative reasoning question.""" |
|
|
|
SUBSTITUTIONS = [ |
|
("an ", ""), |
|
("a ", ""), |
|
(".$", "$"), |
|
("\\$", ""), |
|
(r"\ ", ""), |
|
(" ", ""), |
|
("mbox", "text"), |
|
(",\\text{and}", ","), |
|
("\\text{and}", ","), |
|
("\\text{m}", "\\text{}"), |
|
("\\le", "<"), |
|
] |
|
REMOVED_EXPRESSIONS = [ |
|
"square", |
|
"ways", |
|
"integers", |
|
"dollars", |
|
"mph", |
|
"inches", |
|
"ft", |
|
"hours", |
|
"km", |
|
"units", |
|
"\\ldots", |
|
"sue", |
|
"points", |
|
"feet", |
|
"minutes", |
|
"digits", |
|
"cents", |
|
"degrees", |
|
"cm", |
|
"gm", |
|
"pounds", |
|
"meters", |
|
"meals", |
|
"edges", |
|
"students", |
|
"childrentickets", |
|
"multiples", |
|
"\\text{s}", |
|
"\\text{.}", |
|
"\\text{\ns}", |
|
"\\text{}^2", |
|
"\\text{}^3", |
|
"\\text{\n}", |
|
"\\text{}", |
|
r"\mathrm{th}", |
|
r"^\circ", |
|
r"^{\circ}", |
|
r"\;", |
|
r",\!", |
|
"{,}", |
|
'"', |
|
"\\dots", |
|
"\n", |
|
"\r", |
|
"\f", |
|
"\%", |
|
] |
|
for before, after in SUBSTITUTIONS: |
|
final_answer = final_answer.replace(before, after) |
|
for expr in REMOVED_EXPRESSIONS: |
|
final_answer = final_answer.replace(expr, "") |
|
|
|
|
|
|
|
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) |
|
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) |
|
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) |
|
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) |
|
assert "\n" not in final_answer |
|
assert "\r" not in final_answer |
|
assert "\f" not in final_answer |
|
if len(re.findall(r"finalansweris(.*)", final_answer)) > 0: |
|
final_answer = re.findall(r"finalansweris(.*)", final_answer)[-1] |
|
|
|
if len(re.findall(r"answer?is:?(.*)", final_answer)) > 0: |
|
final_answer = re.findall(r"answer?is:?(.*)", final_answer)[-1] |
|
|
|
if len(re.findall(r"oxed\{(.*?)\}", final_answer)) > 0: |
|
final_answer = re.findall(r"oxed\{(.*?)\}", final_answer)[-1] |
|
|
|
if len(re.findall(r"\$(.*?)\$", final_answer)) > 0: |
|
final_answer = re.findall(r"\$(.*?)\$", final_answer)[-1] |
|
final_answer = final_answer.strip() |
|
if "rac" in final_answer and "\\frac" not in final_answer: |
|
final_answer = final_answer.replace("rac", "\\frac") |
|
|
|
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) |
|
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) |
|
final_answer = final_answer.replace("$", "") |
|
|
|
if final_answer.replace(",", "").isdigit(): |
|
final_answer = final_answer.replace(",", "") |
|
|
|
return final_answer |
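
# Example: normalize_final_answer("The answer is $3$ units") -> "3"
# (spaces and unit words are stripped, then the last $...$ group is kept).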
|
|
|
|
|
def naive_parse(answer: str) -> str: |
|
""" |
|
Extracts and returns the numeric digits from the input string, processing them in reverse order |
|
until a non-numeric character is encountered after encountering the first numeric character. |
|
|
|
Args: |
|
answer (str): The input string to parse. |
|
|
|
Returns: |
|
str: A string consisting of the numeric digits extracted from the input, in their original order. |
|
|
|
Example: |
|
>>> naive_parse("abc123def") |
|
'123' |
|
>>> naive_parse("def456ghi") |
|
'456' |
|
>>> naive_parse("no numbers here") |
|
'' |
|
""" |
|
out = [] |
|
start = False |
|
end = False |
|
    for ch in reversed(list(answer)):
        if ch in "0123456789" and not end:
            start = True
            out.append(ch)
        else:
            if start:
                end = True
|
|
|
out = reversed(out) |
|
return "".join(out) |
|
|
|
|
|
def validate_answer_is_numeric(x: str | int | float) -> int:
    FLOAT_TOLERANCE = 0.2
    try:
        f = float(x)
        x = round(f)
        # Reject answers that are not close to an integer.
        if abs(x - f) > FLOAT_TOLERANCE:
            x = -1
    except Exception:
        x = -1
    return x
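
# Examples: validate_answer_is_numeric("42.1") -> 42 (close enough to an
# integer); validate_answer_is_numeric("42.5") -> -1 (outside the tolerance);
# validate_answer_is_numeric("n/a") -> -1 (not numeric).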
|
|
|
|
|
def filter_answers(answers: List[str]) -> List[int]: |
|
formatted_answers = [validate_answer_is_numeric(a) for a in answers] |
|
|
|
|
|
formatted_answers = [a for a in formatted_answers if a >= 0] |
|
|
|
formatted_answers = [a % 1_000 for a in formatted_answers] |
|
|
|
formatted_answers = [a for a in formatted_answers if a <= 999] |
|
return formatted_answers |
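
# Example: filter_answers(["1999", "oops", "7"]) -> [999, 7]
# ("oops" fails the numeric check; answers are reduced modulo 1000 to fit the
# 0-999 answer format).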
|
|
|
|
|
def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool: |
|
def do_answers_match(ref_answer: str, model_answer: str) -> bool: |
|
ref_sympy = parse_latex(ref_answer) |
|
model_sympy = parse_latex(model_answer) |
|
diff = simplify(ref_sympy - model_sympy) |
|
        return bool(diff.is_zero or -1e-12 < N(diff) < 1e-12)
|
|
|
try: |
|
result = do_answers_match(ref_answer, model_answer) |
|
return result |
|
except Exception as e: |
|
print(e) |
|
return False |
|
|
|
|
|
def check_string_match(ref_answer: str, model_answer: str) -> bool: |
|
try: |
|
return ref_answer == model_answer |
|
except Exception as e: |
|
print(e) |
|
return False |
|
|
|
|
|
def check_answer(ref_answer: str, model_answer: str) -> bool: |
|
|
|
correct = check_string_match(ref_answer, model_answer) |
|
if correct: |
|
return True |
|
|
|
|
|
correct = check_sympy_equivalence(ref_answer, model_answer) |
|
if correct: |
|
return True |
|
|
|
return False |
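
# Example (sketch, requires sympy's LaTeX parser):
#   check_answer("\\frac{1}{2}", "\\frac{2}{4}") -> True
# (the string match fails, but sympy simplifies the difference to zero).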
|
|
|
|
|
debug = False |
|
model_id = "athstral-7B-v0.m1" |
|
revision = "main" |
|
system_prompt = "{}" |
|
validation_set = "kaggle-validation-set-medium" |
|
is_submission = True |
|
num_samples = 4 |
|
num_generations = 4 |
|
temperature = 0.8 |
|
is_quantized = False |
|
restart_on_fail = False |
|
top_p = 1.0 |
|
top_k = 0 |
|
max_new_tokens = 2048 |
|
|
|
push_to_hub = False |
|
notebook_name = "" |
|
|
|
config = Config(
    debug=debug,
    push_to_hub=push_to_hub,
|
model_id=model_id, |
|
revision=revision, |
|
system_prompt=system_prompt, |
|
validation_set=validation_set, |
|
is_quantized=is_quantized, |
|
restart_on_fail=restart_on_fail, |
|
is_submission=is_submission, |
|
num_samples=num_samples, |
|
num_generations=num_generations, |
|
do_sample=True, |
|
temperature=temperature, |
|
top_p=top_p, |
|
top_k=top_k, |
|
max_new_tokens=max_new_tokens |
|
) |
|
|
|
|
|
print(f"=== Running submission with config ===\n\n{config}") |
|
|
|
|
|
def parse_data_chunk(data_chunk):
    """
    Parse a given data chunk string into a list of individual data entries.

    The function splits the input string by the delimiter "data:" and removes any
    leading or trailing whitespace from each resulting chunk. Empty chunks are
    filtered out from the final list.

    Parameters:
        data_chunk (str): The input string containing data chunks separated by "data:".

    Returns:
        list: A list of individual data entries with whitespace stripped.
    """
    chunks = data_chunk.split("data:")
    return [chunk.strip() for chunk in chunks if chunk.strip()]
|
|
|
def generate(messages, temperature):
    """
    Stream model tokens for the given chat messages.

    Yields (token, is_error) pairs. This is a minimal sketch: it assumes an
    InferenceClient version that still exposes ``post`` and an endpoint that
    returns OpenAI-style server-sent events ("data: {...}" chunks).
    """
    response = client.post(
        json={
            "messages": messages,
            "temperature": temperature,
            "max_tokens": config.max_new_tokens,
            "stream": True,
        },
        stream=True,
    )
    for chunk in response:
        chunk = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
        data_chunks = parse_data_chunk(chunk)
        try:
            for data_chunk in data_chunks:
                # OpenAI-style streams terminate with a literal "[DONE]" chunk.
                if data_chunk == "[DONE]":
                    continue
                chunk_json = json.loads(data_chunk)
                if "error" in chunk_json and chunk_json["error"]:
                    yield chunk_json["error"], True
                    break
                delta = chunk_json["choices"][0]["delta"]
                content = delta.get("content", "")
                if content != "":
                    yield content, False
        except (json.JSONDecodeError, KeyError) as e:
            print(f"func: generate error occurred\nchunk:{chunk}\nerror:{e}")
            raise
|
|
|
def get_majority_text(data): |
|
from collections import Counter |
|
|
|
|
|
answer_counts = Counter(data["model_answers"]) |
|
|
|
|
|
majority_response = answer_counts.most_common(1)[0][0] |
|
|
|
|
|
majority_index = data["model_answers"].index(majority_response) |
|
|
|
|
|
return data["gen_texts"][majority_index] |
|
|
|
|
|
def extract_solution(text): |
|
|
|
parts = text.split("### Solution:", 1) |
|
if len(parts) > 1: |
|
|
|
return parts[1].strip() |
|
else: |
|
|
|
return "" |
|
|
|
|
|
def process_code( |
|
example: Dict[str, Any], |
|
config: Config, |
|
restart_on_fail: bool = False, |
|
last_step: bool = False, |
|
) -> Dict[str, Any]: |
|
gen_text = example["gen_texts"] |
|
num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL)) |
|
|
|
if num_python_blocks == 0: |
|
if restart_on_fail: |
|
print("no code has ever been generated, RESTARTING") |
|
|
|
example["gen_texts"] = example["text"] |
|
else: |
|
print("no code has ever been generated, STOP") |
|
example["should_prune"] = True |
|
example["has_code"] = False |
|
return example |
|
|
|
if gen_text[-10:] != "```output\n" and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]): |
|
num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL)) |
|
if num_output_blocks == 0: |
|
print("the model hallucinated the code answer") |
|
example["should_prune"] = True |
|
return example |
|
|
|
if "boxed" in gen_text[-100:]: |
|
try: |
|
answer = normalize_final_answer(extract_boxed_answer(gen_text[-100:])) |
|
except Exception: |
|
answer = "-1" |
|
else: |
|
answer = normalize_final_answer(gen_text[-100:]) |
|
|
|
example["model_answers"] = answer |
|
if not config.is_submission: |
|
example["corrects"] = check_answer(example["ground_truth"], answer) |
|
example["should_prune"] = True |
|
print("Answer is: ", answer, example["ground_truth"], example["corrects"]) |
|
return example |
|
|
|
if last_step: |
|
|
|
return example |
|
|
|
if gen_text[-10:] != "```output\n": |
|
|
|
print("warning: output block not found: ", gen_text[-40:]) |
|
if restart_on_fail: |
|
example["gen_texts"] = example["text"] |
|
else: |
|
example["should_prune"] = True |
|
return example |
|
|
|
code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True) |
|
|
|
TRUNCATION_LIMIT = 200 |
|
if len(code_result) > TRUNCATION_LIMIT: |
|
code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)" |
|
example["gen_texts"] = gen_text + f"{code_result}\n```" |
|
|
|
return example |
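
# Example (sketch): when gen_texts ends with an open "```output\n" fence, the
# last ```python block is executed and its stdout is spliced into the fence:
#   "...```python\nprint(1 + 1)\n```\n```output\n"
# becomes
#   "...```python\nprint(1 + 1)\n```\n```output\n2\n```"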
|
|
|
|
|
def solve_problem(problem, temperature, progress=gr.Progress()): |
|
""" |
|
yield token: string, stop: bool |
|
""" |
|
problem = apply_template({"prompt": problem}, prompt=config.system_prompt) |
|
print(f"Problem: {problem}") |
|
|
|
sample = { |
|
"problem": problem, |
|
"ground_truth": "unknown", |
|
"text": "## Solution:\n", |
|
"gen_texts": "", |
|
"should_prune": False, |
|
"problem_index": -1, |
|
"model_answers": "-1", |
|
"has_code": True, |
|
"corrects": False, |
|
} |
|
|
|
for step in progress.tqdm( |
|
range(config.num_generations), desc="Generating candidates" |
|
): |
|
|
|
        step_response = sample["gen_texts"]
|
|
|
messages = [ |
|
{"role": "user", "content": sample["problem"]}, |
|
{"role": "assistant", "content": sample["gen_texts"]}, |
|
] |
|
|
|
stop = False |
|
|
|
        for response_message, error in generate(messages, temperature):
            if response_message is not None:
                step_response += response_message
                yield step_response, False
|
|
|
if error: |
|
stop = True |
|
|
|
sample["gen_texts"] = step_reponse |
|
|
|
|
|
sample = process_code( |
|
sample, |
|
config=config, |
|
restart_on_fail=config.restart_on_fail, |
|
last_step=(step == (config.num_generations - 1)), |
|
) |
|
sample["gen_texts"] = sample["gen_texts"] + "\n" |
|
|
|
        run_code_response = sample["gen_texts"].replace(step_response, "")
|
|
|
        # Stream the freshly appended code output character by character.
        for output_message in run_code_response:
            step_response += output_message
            yield step_response, False
|
|
|
if sample["should_prune"] or stop: |
|
break |
|
|
|
yield sample["gen_texts"], True |
|
|
|
features = Features({ |
|
'id': Value('int64'), |
|
'problem': Value('string'), |
|
'answer': Value('string'), |
|
|
|
|
|
}) |
|
|
|
|
|
example_data = load_dataset(
    "AI-MO/aimo-validation-math-level-5",
    split="train",
    token=os.environ.get("HF_DATASET_TOKEN", None),
    features=features,
)
|
|
|
|
|
|
|
with open( "app.css", "r") as f: |
|
css = f.read() |
|
|
|
|
|
latex_delimiters = [ |
|
{"left": "[", "right": "]", "display": True}, |
|
] |
|
|
|
|
|
def get_random_problem(): |
|
example = random.choice(list(example_data)) |
|
problem = example["problem"] |
|
return problem |
|
|
|
|
|
def update_example_problem(): |
|
problem_example_text = get_random_problem() |
|
return problem_example_text, problem_example_text |
|
|
|
|
|
def clear(): |
|
problem_example_text = get_random_problem() |
|
return "", 0.1, "", problem_example_text, problem_example_text |
|
|
|
|
|
def preprocess_output(text): |
|
return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)") |
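
# Example: preprocess_output(r"\(x\)") -> r"\\(x\\)"
# (doubles the backslashes of inline-math delimiters so they survive
# Markdown rendering).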
|
|
|
|
|
with gr.Blocks(css=css, title="Math Olympiad Solver") as demo: |
|
btn_list = [] |
|
problem_input_ele_list = [] |
|
|
|
problem_example_text = get_random_problem() |
|
|
|
with gr.Row(elem_classes="title"): |
|
gr.HTML("Math Olympiad Solver", elem_classes="title-content") |
|
|
|
with gr.Row(elem_classes="sub-title"): |
|
gr.HTML( |
|
"<div>Demo of the maths solving with AI Models</a>. Example data are drawn randomly generated.</div>", |
|
elem_classes="sub-title-content", |
|
) |
|
|
|
with gr.Row(elem_classes="main-area"): |
|
with gr.Column(scale=1, elem_classes="left"): |
|
with gr.Row(elem_classes="probelm-example-container"): |
|
with gr.Blocks(elem_classes="probelm-example-title"): |
|
gr.HTML("Problem example", elem_classes="probelm-example-title-content") |
|
|
|
with gr.Blocks(elem_classes="action-container"): |
|
another_btn = gr.Button( |
|
"", |
|
elem_classes="probelm-example-another", |
|
icon="./static/images/reset.png", |
|
) |
|
copy_btn = gr.Button("Copy", elem_classes="probelm-example-copy") |
|
|
|
problem_example = gr.HTML( |
|
problem_example_text, |
|
elem_classes="probelm-example-content", |
|
) |
|
|
|
with gr.Row(elem_classes="probelm-input-container"): |
|
inp = gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True) |
|
problem_markdown = gr.Markdown( |
|
visible=False, |
|
latex_delimiters=[ |
|
{"left": "[", "right": "]", "display": True}, |
|
{"left": "$", "right": "$", "display": False}, |
|
{"left": r"\(", "right": r"\)", "display": False}, |
|
], |
|
) |
|
|
|
inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown]) |
|
problem_input_ele_list.append(inp) |
|
problem_input_ele_list.append(problem_markdown) |
|
|
|
with gr.Accordion("Advanced Options", open=False): |
|
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature") |
|
|
|
with gr.Row() as btn_area: |
|
btn_clear = gr.Button("Clear", elem_classes="clear-btn") |
|
btn_run = gr.Button("Run", elem_classes="run-btn") |
|
btn_list.append(btn_clear) |
|
btn_list.append(btn_run) |
|
|
|
with gr.Column(scale=1, elem_classes="right"): |
|
gr.HTML("Solution", elem_classes="solution-title-content") |
|
out = gr.Markdown( |
|
elem_classes="solution-content", |
|
latex_delimiters=[ |
|
{"left": "[", "right": "]", "display": True}, |
|
{"left": "$", "right": "$", "display": False}, |
|
{"left": r"\(", "right": r"\)", "display": False}, |
|
], |
|
) |
|
|
|
problem_example_text_hidden = gr.Markdown(value=problem_example_text, visible=False) |
|
|
|
    def solve_problem_wrapper(inp_text, temperature):
        new_running_btn = gr.Button("", elem_classes="run-btn running-btn")

        for after_tokens, stop in solve_problem(inp_text, temperature):
            yield preprocess_output(after_tokens), new_running_btn

            if stop:
                btn_run = gr.Button("Run", elem_classes="run-btn")
                yield preprocess_output(after_tokens), btn_run
|
|
|
def mount_run_btn(btn): |
|
btn.click(fn=solve_problem_wrapper, inputs=[inp, temperature], outputs=[out, btn_list[1]]) |
|
btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list) |
|
|
|
def get_run_after_problem_input(): |
|
return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=False), gr.Markdown( |
|
visible=True, |
|
latex_delimiters=[ |
|
{"left": "[", "right": "]", "display": True}, |
|
{"left": "$", "right": "$", "display": False}, |
|
], |
|
elem_classes="problem-input-markdown", |
|
) |
|
|
|
def get_init_problem_input(): |
|
return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True), gr.Markdown( |
|
visible=False, |
|
latex_delimiters=[ |
|
{"left": "[", "right": "]", "display": True}, |
|
{"left": "$", "right": "$", "display": False}, |
|
], |
|
) |
|
|
|
copy_btn.click(fn=lambda example: example, inputs=[problem_example_text_hidden], outputs=[inp]) |
|
|
|
btn_clear.click( |
|
fn=clear, |
|
inputs=[], |
|
outputs=[ |
|
inp, |
|
temperature, |
|
out, |
|
problem_example, |
|
problem_example_text_hidden, |
|
], |
|
) |
|
|
|
btn_clear.click(get_init_problem_input, None, outputs=problem_input_ele_list) |
|
|
|
mount_run_btn(btn_run) |
|
|
|
demo.load( |
|
update_example_problem, |
|
inputs=None, |
|
outputs=[ |
|
problem_example, |
|
problem_example_text_hidden, |
|
], |
|
) |
|
|
|
another_btn.click( |
|
fn=update_example_problem, |
|
inputs=[], |
|
outputs=[ |
|
problem_example, |
|
problem_example_text_hidden, |
|
], |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.queue(default_concurrency_limit=5).launch(share=True) |
|
|