"""
This script is adapted from Qwen2.5-Math
https://github.com/QwenLM/Qwen2.5-Math/blob/main/evaluation/grader.py
"""

import re
import regex
import multiprocessing
from math import isclose
from typing import Union
from collections import defaultdict

from sympy import simplify, N
from sympy.parsing.sympy_parser import parse_expr
from sympy.parsing.latex import parse_latex


def latex2sympy(sympy: str, variable_values={}):
    r"""Convert a LaTeX string to a sympy object (or a list of them) via an ANTLR-style parser.

    The input is first normalised with plain string rewrites: \dfrac/\tfrac
    become \frac, \mathrm{T} and \mathrm{d}/{\rm d} are simplified, matrix
    environments become bmatrix, ``(n)_{k}`` is rewritten as
    ``\frac{(n)!}{((n)-(k))!}``, and \displaystyle, spacing commands and ``$``
    are replaced by spaces. The text is then lexed/parsed and converted with
    ``convert_relation``.

    NOTE(review): MathErrorListener, InputStream, PSLexer, CommonTokenStream,
    PSParser and convert_relation are not imported or defined anywhere in the
    visible part of this file, so calling this function presumably raises
    NameError unless they are provided elsewhere -- TODO confirm.

    Args:
        sympy: LaTeX source text (the parameter name shadows the ``sympy``
            package inside this function).
        variable_values: stored in the module global ``VARIABLE_VALUES`` when
            non-empty; otherwise that global is reset to a fresh ``{}``.
            Presumably read by the converter -- TODO confirm.

    Returns:
        A list with one converted relation per item when the parse produced a
        relation list, otherwise a single converted relation.

    Side effects:
        Sets the module globals ``frac_type`` (only when some fraction command
        is present) and ``VARIABLE_VALUES``.
    """
    # record frac: later checks overwrite earlier ones, so \tfrac wins over
    # \dfrac, which wins over \frac; frac_type is untouched if none appears
    global frac_type
    if sympy.find(r'\frac') != -1:
        frac_type = r'\frac'
    if sympy.find(r'\dfrac') != -1:
        frac_type = r'\dfrac'
    if sympy.find(r'\tfrac') != -1:
        frac_type = r'\tfrac'
    sympy = sympy.replace(r'\dfrac', r'\frac')
    sympy = sympy.replace(r'\tfrac', r'\frac')
    # Translate Transpose
    sympy = sympy.replace(r'\mathrm{T}', 'T', -1)
    # Translate Derivative
    sympy = sympy.replace(r'\mathrm{d}', 'd', -1).replace(r'{\rm d}', 'd', -1)
    # Translate Matrix: \left[\begin{matrix} ... \end{matrix}\right] -> bmatrix
    sympy = sympy.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}', -1).replace(r'\end{matrix}\right]', r'\end{bmatrix}', -1)
    # Translate Permutation: (n)_{k} -> \frac{(n)!}{((n)-(k))!}
    sympy = re.sub(r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}", r"\\frac{(\1)!}{((\1)-(\2))!}", sympy)
    # Remove \displaystyle
    sympy = sympy.replace(r'\displaystyle', ' ', -1)
    # Remove spacing commands: \quad, \qquad, ~ and \,
    sympy = sympy.replace(r'\quad', ' ', -1).replace(r'\qquad', ' ', -1).replace(r'~', ' ', -1).replace(r'\,', ' ', -1)
    # Remove $
    sympy = sympy.replace(r'$', ' ', -1)

    # variable values (published through a module global)
    global VARIABLE_VALUES
    if len(variable_values) > 0:
        VARIABLE_VALUES = variable_values
    else:
        VARIABLE_VALUES = {}

    # setup listener shared by the lexer and the parser
    matherror = MathErrorListener(sympy)

    # stream input
    stream = InputStream(sympy)
    lex = PSLexer(stream)
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    tokens = CommonTokenStream(lex)
    parser = PSParser(tokens)

    # remove default console error listener
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    # process the input
    return_data = None
    math = parser.math()

    # if a list
    if math.relation_list():
        return_data = []

        # go over list items, converting each relation separately
        relation_list = math.relation_list().relation_list_content()
        for list_item in relation_list.relation():
            expr = convert_relation(list_item)
            return_data.append(expr)

    # if not, do default: a single relation
    else:
        relation = math.relation()
        return_data = convert_relation(relation)

    return return_data


def math_answer_cleaning(answer, dataset_name):
    r"""
    remove irrelevant strings and unify the answer format before checking whether the answers are equal

    The same normalisation is applied to predictions and references so they
    can be compared as strings:
    - unwrap an answer completely wrapped in ``\text{...}``;
    - drop thousands separators and ``\$``;
    - unify fraction commands;
    - drop degree marks, ``\quad``, spaces and newlines;
    - rewrite ``a\times10^{b}`` as ``aeb``;
    - strip ``\text{...}`` units;
    - flatten ``d^{n}`` to ``d^n`` for digit runs;
    - lower-case the result.

    Some datasets get extra, dataset-specific clean-up.

    Args:
        answer: raw answer string.
        dataset_name: "collegemath", "gsm8k" and "gaokao2023en" get extra
            handling; any other name gets only the generic clean-up.

    Returns:
        The cleaned answer string.
    """
    def _is_completely_wrapped_by_text(input_string):
        # Return the inner text (with "(", ")" and "," removed) if the whole
        # string is \text{...}; otherwise None.
        pattern = r'^\\text{(.*)}$'
        match = re.match(pattern, input_string)
        if match:
            ## input_string is completely wrapped by \text{}
            extracted_content = match.group(1)
            extracted_content = extracted_content.replace("(", "").replace(")", "").replace(",", "")
            return extracted_content
        else:
            return None

    ## remove irrelevant \\text and space
    extracted_content = _is_completely_wrapped_by_text(answer)
    answer = extracted_content if extracted_content else answer

    # Raw strings below: the former "\!", "\$", "\c" and "\q" literals were
    # invalid escape sequences (warned about since Python 3.6, SyntaxWarning
    # from 3.12). The string values are unchanged.
    ## e.g., convert 5,\!460 into 5460; convert 14{,}916 into 14916 convert \$4 into 4
    answer = answer.replace(r",\!", "").replace("{,}", "").replace(r"\$", "")
    ## e.g., convert \dfrac{3}{2} into frac{3}{2}
    answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
    ## e.g., convert 121^\circ into 121
    answer = answer.replace(r"^\circ", "")
    answer = answer.replace(r"^{\circ}", "")
    ## remove \quad
    answer = answer.replace(r"\quad", "")
    ## remove space
    answer = answer.replace(" ", "")
    ## remove real newlines and the literal two-character sequence backslash-n
    # NOTE(review): the second replace also mangles commands such as \neq or
    # \nu (e.g. "\neq" -> "eq"); it is applied to both sides symmetrically.
    answer = answer.replace("\n", "").replace("\\n", "")
    ## e.g., convert 3.54\times10^{10} into 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
    ## e.g., convert 3.54\times10^10 into 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
    ## e.g., convert 558\,\text{nm} into 558
    answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
    ## e.g., convert 558\text{nm} into 558
    answer = re.sub(r'\\text\{.*?\}', '', answer)
    ## e.g., convert 2^{10} into 2^10
    answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
    ## lowercase (every dataset-specific string below must therefore be lower-case)
    answer = answer.lower()

    if dataset_name == "collegemath":
        ## convert 558\mathrm{ft} into 558
        answer = re.sub(r'\\mathrm\{.*?\}', '', answer)
        ## clean noisy answer
        answer = re.sub(r'\$\([^)]*\)', '', answer)
        if answer.endswith("-"):
            answer = answer[:-1]
        if answer.endswith("."):
            answer = answer[:-1]
        if answer.endswith("hours"):
            answer = answer[:-len("hours")]
        ## extract final answer after the first '=' and then after the first ':'
        if "=" in answer:
            answer = answer.split("=", 1)[1]
        if ":" in answer:
            answer = answer.split(":", 1)[1]
        ## \emptyset and \oslash are treated as the same empty-set symbol here
        answer = answer.replace("\\emptyset", "\\oslash")
    if dataset_name == "gsm8k":
        # Example: 5,600 -> 5600
        answer = answer.replace(',', '')
    if dataset_name == "gaokao2023en":
        # The answer was lower-cased above, so every unit must be lower-case
        # as well; the former "degreesontheBreadusscale" (capital B) could
        # never match. Order matters: "kilometers" is removed before "meters".
        unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters', 'degreesonthebreadusscale', '$', 'a.m.', 'am', 'minutes']
        for unit in unit_strings:
            answer = answer.replace(unit, "")

    return answer


def extract_final_answer(output):
    r"""Extract the contents of every ``\boxed{...}`` in a model output.

    Braces inside a box are matched to any depth, so answers such as
    ``\boxed{\frac{1}{\sqrt{x^{2}+1}}}`` are extracted whole. The former regex
    allowed only two levels of nested braces inside the box and silently found
    nothing for deeper answers.

    Matching rules:
    - Boxes are found left to right and never overlap: a ``\boxed`` nested
      inside another box is part of the outer match.
    - A ``\boxed{`` whose braces never close is skipped.
    - ``\boxed`` must be followed immediately by ``{``.

    Args:
        output: model output text. None is treated as containing no boxes.

    Returns:
        ``(last_answer, all_answers)``:
        - ``last_answer`` is the content of the last box, or None if there
          are no boxes.
        - ``all_answers`` lists every box content in order of appearance.
    """
    if output is None:
        return None, []

    marker = "\\boxed{"
    all_matches = []
    search_from = 0
    while True:
        start = output.find(marker, search_from)
        if start == -1:
            break
        content_start = start + len(marker)
        depth = 1
        pos = content_start
        while pos < len(output) and depth > 0:
            char = output[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            # pos is one past the brace that closes this box
            all_matches.append(output[content_start:pos - 1])
            search_from = pos
        else:
            # Unbalanced box: keep looking for a later \boxed{ inside it.
            search_from = start + 1

    extracted_answer = all_matches[-1] if all_matches else None
    return extracted_answer, all_matches


def round_number(answer):
    """Round numeric answers with magnitude below 1 to two significant digits.

    This handles answers like ``5.56e-10``, which becomes ``"5.6e-10"``. The
    result is still a string. Non-numeric input and numbers with magnitude
    >= 1 are returned unchanged.

    The former check ``float(answer) < 1`` also matched every negative number.
    So "-123" and "-124" both became "-1.2e+02" and were wrongly treated as
    equal by the caller's string comparison.
    """
    try:
        value = float(answer)
    except (TypeError, ValueError, OverflowError):
        return answer

    if abs(value) < 1:
        return f"{value:.2g}"

    return answer


def choice_answer_clean(pred: str):
    """Reduce a multiple-choice prediction to a single option letter.

    Surrounding newlines, trailing "." and "/", spaces and leading ":" are
    trimmed. If any standalone letter A-E occurs (case-insensitively), the
    last one is returned in upper case. Otherwise the trimmed text itself is
    returned, minus any remaining trailing "." or "/".
    """
    trimmed = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    option_letters = re.findall(r"\b(A|B|C|D|E)\b", trimmed.upper())
    if option_letters:
        candidate = option_letters[-1]
    else:
        candidate = trimmed.strip().strip(".")
    # trailing punctuation can survive the first pass, so trim it once more
    return candidate.rstrip(".").rstrip("/")


def parse_digits(num):
    """Parse ``num`` as a float, accepting thousands separators and percentages.

    Commas are removed first ("1,234" -> 1234.0). If the text is still not a
    float but ends with "%" (optionally LaTeX-escaped as a backslash before
    "%"), the preceding number divided by 100 is returned ("50%" -> 0.5).

    Args:
        num: any value; it is converted with ``str()`` first.

    Returns:
        The parsed float, or None when ``num`` cannot be interpreted.
    """
    # A plain str.replace is enough for a literal comma.
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:  # float() on a str only raises ValueError
        pass

    if text.endswith("%"):
        text = text[:-1]
        if text.endswith("\\"):
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None


def is_digit(num):
    """Return True when ``parse_digits`` can interpret ``num`` as a number."""
    parsed_value = parse_digits(num)
    return parsed_value is not None


def str_to_pmatrix(input_str):
    r"""Rewrite brace-delimited, comma-separated vectors as LaTeX pmatrix strings.

    Each ``{...,...}`` span found in ``input_str`` becomes a pmatrix with one
    row per comma-separated item, e.g. ``{1,2}`` ->
    ``\begin{pmatrix}1\\2\end{pmatrix}``. Several spans are joined with ", ".
    Returns "" when no such span exists.

    Rows are separated by the two-character LaTeX row break ``\\``. That is
    what ``math_equal``'s matrix comparison splits on. The former code
    inserted a single backslash, so a converted reference never split into
    rows.

    NOTE(review): the greedy pattern merges several ``{..}`` groups on one
    line into a single match; this is kept as-is for compatibility.
    """
    row_separator = "\\\\"  # LaTeX row break: two backslash characters
    pmatrix_list = []
    for m in re.findall(r"\{.*,.*\}", input_str.strip()):
        rows = m.strip("{}").replace(",", row_separator)
        pmatrix_list.append(r"\begin{pmatrix}" + rows + r"\end{pmatrix}")

    return ", ".join(pmatrix_list)


def _matrix_rows(matrix_text):
    """Return the stripped, non-empty rows of a pmatrix/bmatrix LaTeX string.

    The begin/end markers are removed by slicing with the pmatrix marker
    lengths; the bmatrix markers have the same lengths. The remaining body is
    split on the LaTeX row separator (two backslashes).
    """
    body = matrix_text[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")]
    return [row.strip() for row in body.split("\\\\") if row.strip()]


def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Checks are tried in this order; the first one that succeeds returns True:
    1. Case-insensitive string match.
    2. Multiple-choice letter match.
    3. Fraction/arithmetic match.
    4. Numeric match: optionally allows a factor of 100 either way for
       percentages.
    5. Element-wise (recursive) match of bracketed tuples and of
       pmatrix/bmatrix matrices.
    6. Single-equation handling.
    7. sympy-based symbolic comparison.

    Args:
        prediction: model answer. Non-string values are stringified for the
            string-based checks.
        reference: ground-truth answer.
        include_percentage: also accept ``reference / 100`` and
            ``reference * 100`` in the numeric comparison.
        is_close: compare numbers with ``numeric_equal`` (relative tolerance)
            instead of exact ``==``.
        timeout: run the final sympy comparison through ``call_with_timeout``.
            A timeout counts as "not equal".

    Returns:
        True if any check considers the answers equal, else False.
    """
    if prediction is None or reference is None:
        return False

    # Stringify first: the annotations allow bool/float, which have no
    # .strip() and would crash the string-only helpers below.
    pred_text = str(prediction)
    ref_text = str(reference)
    if pred_text.strip().lower() == ref_text.strip().lower():
        return True
    if reference in ["A", "B", "C", "D", "E"] and choice_answer_clean(pred_text) == reference:
        return True

    # fraction equal
    if fraction_equal(pred_text, ref_text):
        return True

    try:  # numerical equal
        if round_number(prediction) == round_number(reference):
            return True
        if is_digit(prediction) and is_digit(reference):
            pred_num = parse_digits(prediction)
            ref_num = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [ref_num / 100, ref_num, ref_num * 100]
            else:
                gt_result = [ref_num]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(pred_num, item):
                            return True
                    elif item == pred_num:
                        return True
                except Exception:
                    continue
            # both sides are numbers: no symbolic fallback
            return False
    except Exception:
        pass  # best effort: fall through to the symbolic checks

    if not prediction and prediction not in [0, False]:
        return False

    # symbolic equal
    reference = ref_text.strip()
    prediction = pred_text.strip()

    ## pmatrix (amps)
    if "pmatrix" in prediction and "pmatrix" not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    bracketed = r"(\(|\[).+(\)|\])"
    if (
        regex.match(bracketed, prediction) is not None
        and regex.match(bracketed, reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and all(
            math_equal(pred_part, ref_part, include_percentage, is_close)
            for pred_part, ref_part in zip(pred_parts, ref_parts)
        ):
            return True

    ## matrices: compare row by row, cell by cell
    matrix_begins = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_ends = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_begins)
        and prediction.endswith(matrix_ends)
        and reference.startswith(matrix_begins)
        and reference.endswith(matrix_ends)
    ):
        pred_lines = _matrix_rows(prediction)
        ref_lines = _matrix_rows(reference)
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_cells = pred_line.split("&")
                ref_cells = ref_line.split("&")
                if len(pred_cells) != len(ref_cells) or not all(
                    math_equal(pred_cell, ref_cell, include_percentage, is_close)
                    for pred_cell, ref_cell in zip(pred_cells, ref_cells)
                ):
                    matched = False
                    break
        if matched:
            return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        # compare "lhs - (rhs)" forms, allowing an overall sign flip
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # e.g. "x=5" vs "5": compare the right-hand side only
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close
        ):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close
        ):
            return True

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False


def numeric_equal(prediction: float, reference: float):
    """Return True if the two numbers agree within a relative tolerance of 1e-4.

    Only a relative tolerance is used (``abs_tol`` keeps its default of 0).
    A reference of exactly 0 therefore only matches a prediction of exactly 0.
    """
    # The relative tolerance has a significant impact on results for the
    # synthesized GSM-Hard dataset, so change it with care.
    relative_tolerance = 1e-4
    return isclose(prediction, reference, rel_tol=relative_tolerance)


def fraction_equal(prediction, reference):
    r"""Return True if two answers agree after rewriting LaTeX fractions and evaluating arithmetic.

    Each ``\frac{a}{b}`` is rewritten to ``(a/b)``. The pattern is non-greedy,
    so nested fractions are not handled. If the rewritten strings are
    identical, the answers match. Otherwise both are evaluated as arithmetic
    expressions and compared with ``==``. A reference that evaluates to a
    falsy value (e.g. 0), or cannot be evaluated, never matches here.

    Evaluation uses a small AST walker instead of ``eval``, for two reasons:
    - Model output is untrusted, and ``eval`` would execute arbitrary Python.
    - ``eval`` reads ``^`` as bitwise XOR, so "2^3" evaluated to 1 and matched
      "1".

    Supported syntax:
    - int/float literals;
    - unary ``+``/``-``;
    - ``+ - * / **``, with ``^`` read as a power;
    - tuples/lists of the above.

    Oversized powers are rejected.
    """
    import ast
    import operator

    max_exponent = 1000
    max_power_bits = 100000

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _evaluate(node):
        # Recursively evaluate a whitelisted arithmetic AST node.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            operand = _evaluate(node.operand)
            if not isinstance(operand, (int, float)):
                raise TypeError("unary operator applied to a non-number")
            return unary_ops[type(node.op)](operand)
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = _evaluate(node.left)
            right = _evaluate(node.right)
            if not isinstance(left, (int, float)) or not isinstance(right, (int, float)):
                raise TypeError("arithmetic on non-numbers")
            if isinstance(node.op, ast.Pow):
                # Refuse powers whose result would be huge (e.g. "9^9^9").
                too_big = abs(right) > max_exponent or (
                    isinstance(left, int)
                    and isinstance(right, int)
                    and left.bit_length() * right > max_power_bits
                )
                if too_big:
                    raise ValueError("power too large to evaluate")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.Tuple):
            return tuple(_evaluate(elt) for elt in node.elts)
        if isinstance(node, ast.List):
            return [_evaluate(elt) for elt in node.elts]
        raise ValueError("unsupported expression")

    def _calculate_numbers(input_string):
        # Best effort: any parse/evaluation failure means "not a number".
        try:
            tree = ast.parse(input_string.strip().replace("^", "**"), mode="eval")
            return _evaluate(tree.body)
        except (SyntaxError, ValueError, TypeError, ArithmeticError, RecursionError, MemoryError):
            return None

    reference = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', str(reference))
    prediction = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', str(prediction))

    if reference == prediction:
        return True

    reference_value = _calculate_numbers(reference)
    prediction_value = _calculate_numbers(prediction)

    if reference_value and reference_value == prediction_value:
        return True

    return False

def symbolic_equal(a, b):
    """Return True if ``a`` and ``b`` are judged equal after parsing with sympy.

    Parsing: each input goes to the first parser that succeeds among
    ``parse_latex``, ``parse_expr`` and ``latex2sympy``. Each parser is tried
    first on the string with doubled backslashes collapsed, then on the raw
    string. If every parser fails, the raw string is kept.

    Comparison steps, in order:
    1. ``str``/``==`` equality.
    2. ``equals`` or ``simplify(a - b) == 0``.
    3. For equations, ``|lhs - rhs|`` of both sides compared with ``equals``.
    4. Numeric evaluation with ``numeric_equal``.
    5. For matrices of the same shape, element-wise comparison after rounding
       to 3 decimals.

    Every step is best-effort: an exception only moves on to the next step.
    ``except Exception`` is used instead of bare ``except:`` so that
    KeyboardInterrupt and SystemExit are no longer swallowed.
    """
    def _parse(s):
        for parse_fn in [parse_latex, parse_expr, latex2sympy]:
            try:
                return parse_fn(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return parse_fn(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    # direct equal
    try:
        if str(a) == str(b) or a == b:
            return True
    except Exception:
        pass

    # simplify equal
    try:
        if a.equals(b) or simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # equation equal
    try:
        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
            return True
    except Exception:
        pass

    # numeric equal
    try:
        if numeric_equal(float(N(a)), float(N(b))):
            return True
    except Exception:
        pass

    # matrix
    try:
        # if a and b are matrix
        if a.shape == b.shape:
            _a = a.applyfunc(lambda x: round(x, 3))
            _b = b.applyfunc(lambda x: round(x, 3))
            if _a.equals(_b):
                return True
    except Exception:
        pass

    return False


def symbolic_equal_process(a, b, output_queue):
    """Child-process entry point: put ``symbolic_equal(a, b)`` on ``output_queue``."""
    output_queue.put(symbolic_equal(a, b))


def math_equal_process(prediction, reference, output_queue):
    """Child-process entry point: put ``math_equal(..., timeout=True)`` on ``output_queue``."""
    is_match = math_equal(prediction, reference, timeout=True)
    output_queue.put(is_match)


def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func(*args, output_queue, **kwargs)`` in a child process and return its queued result.

    ``func`` receives a ``multiprocessing.Queue`` as its last positional
    argument and is expected to put exactly one (small) result on it, as
    ``symbolic_equal_process`` and ``math_equal_process`` do. With non-fork
    start methods, ``func`` and its arguments must be picklable.

    Args:
        func: callable to run in the child process.
        *args: positional arguments passed before the queue.
        timeout: seconds to wait for the child to finish.
        **kwargs: keyword arguments passed to ``func``.

    Returns:
        The value the child put on the queue, or False in any of these cases:
        - the child was still running after ``timeout`` seconds (it is then
          terminated);
        - the child exited with a non-zero code, e.g. it raised;
        - the child exited without producing a result.

        Previously a child that raised before ``put()`` made the blocking
        ``get()`` hang forever.
    """
    import queue  # multiprocessing.Queue.get raises queue.Empty on timeout

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(target=func, args=args + (output_queue,), kwargs=kwargs)
    process.start()
    process.join(timeout)

    try:
        if process.is_alive():
            process.terminate()
            process.join()
            return False

        # A crashed child never called put(); don't wait on an empty queue.
        if process.exitcode != 0:
            return False

        try:
            return output_queue.get(timeout=timeout)
        except queue.Empty:
            return False
    finally:
        # Release the queue's pipe in the parent instead of waiting for GC.
        output_queue.close()


def check_correctness_of_multiple_answer_cases(prediction, reference, all_matches):
    """Fallback check for answers that may list several solutions (used for collegemath).

    Steps, in order:
    1. Return True if the answers match once commas and "$" are removed.
    2. Return False unless the text after the last "=" in the prediction
       equals the text after the last "=" in the reference (with "$"
       removed).
    3. If the reference has no separator ("," / "or" / "and"), return True.
    4. Otherwise the reference is split once on the first separator found
       (checked in the order "or", "and", ","). The last two boxed answers
       must match the two reference parts in either order via
       ``math_equal``. Fewer than two boxed answers means False.

    NOTE(review): "or"/"and" are matched as plain substrings, so text such as
    "random" or "\\land" also counts as a separator.
    """
    def _without_commas_and_dollars(text):
        return text.replace(",", "").replace("$", "")

    if _without_commas_and_dollars(prediction) == _without_commas_and_dollars(reference):
        return True

    last_prediction = prediction.split("=")[-1]
    if last_prediction != reference.split("=")[-1].replace("$", ""):
        return False

    has_several_answers = "," in reference or "or" in reference or "and" in reference
    if not has_several_answers:
        return True

    ## there are multiple answers
    if len(all_matches) <= 1:
        return False

    previous_prediction = all_matches[-2].split("=")[-1]
    cleaned_reference = reference.replace("$", "")
    if "or" in cleaned_reference:
        separator = "or"
    elif "and" in cleaned_reference:
        separator = "and"
    else:
        separator = ","
    gold_parts = cleaned_reference.split(separator, 1)

    last_gold = gold_parts[-1].split("=")[-1]
    previous_gold = gold_parts[-2].split("=")[-1]

    return (
        (math_equal(last_prediction, last_gold) and math_equal(previous_prediction, previous_gold))
        or (math_equal(previous_prediction, last_gold) and math_equal(last_prediction, previous_gold))
    )


def is_equal(model_output, reference, dataset_name):
    """Grade a model output against a reference answer.

    The last ``\\boxed{...}`` answer is extracted from ``model_output``, and
    both sides are normalised with ``math_answer_cleaning``. They are then
    compared with ``math_equal`` in a child process under
    ``call_with_timeout``'s default time limit. For "collegemath" a failed
    comparison falls back to ``check_correctness_of_multiple_answer_cases``.

    Returns:
        True if the answers are judged equal, False otherwise (including when
        no boxed answer is found or ``reference`` is None).
    """
    prediction, boxed_answers = extract_final_answer(model_output)
    if prediction is None or reference is None:
        return False

    prediction = math_answer_cleaning(prediction, dataset_name)
    reference = math_answer_cleaning(reference, dataset_name)

    # The whole comparison runs in a separate process so that a runaway
    # sympy computation cannot block grading.
    if call_with_timeout(math_equal_process, prediction, reference):
        return True

    if dataset_name != "collegemath":
        return False
    return check_correctness_of_multiple_answer_cases(prediction, reference, boxed_answers)