from __future__ import annotations

import re
import requests
from dataclasses import dataclass

import gradio as gr
from tree_sitter import Tree, Node
from tree_sitter_languages import get_parser

def non_whitespace_len(s: str) -> int:  # length of s ignoring whitespace
    return len(re.sub(r"\s", "", s))
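# e.g. non_whitespace_len("a b\nc") == 3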

def get_line_number(index: int, source_code: str | bytes) -> int:
    # Map a byte/character index into source_code to a 0-indexed line number
    total_chars = 0
    line_number = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        total_chars += len(line)
        if total_chars > index:
            return line_number - 1
    return line_number
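# e.g. get_line_number(0, "a\nb\n") == 0 and get_line_number(2, "a\nb\n") == 1 (0-indexed)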

@dataclass
class Span:
    # Represents a slice of a string
    start: int = 0
    end: int | None = None

    def __post_init__(self):
        # If end is None, treat the span as empty: [start, start)
        if self.end is None:
            self.end = self.start

    def extract(self, s: str) -> str:
        # Grab the corresponding substring of s (start/end are byte offsets from tree-sitter)
        return s[self.start: self.end]

    def extract_lines(self, s: str) -> str:
        # Grab the corresponding substring of string s by lines
        return "\n".join(s.splitlines()[self.start:self.end])

    def __add__(self, other: Span | int) -> Span:
        # e.g. Span(1, 2) + Span(2, 4) = Span(1, 4) (concatenation)
        # There are no safety checks: Span(a, b) + Span(c, d) = Span(a, d)
        # and there are no requirements for b = c.
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        elif isinstance(other, Span):
            return Span(self.start, other.end)
        else:
            raise NotImplementedError()

    def __len__(self) -> int:
        # i.e. len(Span(a, b)) == b - a
        return self.end - self.start
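
# Illustrative Span semantics (half-open intervals, checked against the methods above):
#   len(Span(3, 7)) == 4
#   Span(0, 4).extract("hello") == "hell"
#   Span(0, 2).extract_lines("a\nb\nc") == "a\nb"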

def chunk_tree(
    tree: Tree,
    source_code: bytes,
    MAX_CHARS=512 * 3,
    coalesce=50  # Chunks with fewer than this many non-whitespace characters get merged into the next chunk
) -> list[Span]:

    # 1. Recursively form chunks based on the last post (https://docs.sweep.dev/blogs/chunking-2m-files)
    def chunk_node(node: Node) -> list[Span]:
        chunks: list[Span] = []
        current_chunk: Span = Span(node.start_byte, node.start_byte)
        node_children = node.children
        for child in node_children:
            if child.end_byte - child.start_byte > MAX_CHARS:
                # Child is too large by itself: flush the current chunk,
                # recurse into the child, and restart after it
                chunks.append(current_chunk)
                current_chunk = Span(child.end_byte, child.end_byte)
                chunks.extend(chunk_node(child))
            elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
                # Child would overflow the current chunk: flush and start a new chunk at the child
                chunks.append(current_chunk)
                current_chunk = Span(child.start_byte, child.end_byte)
            else:
                # Child fits: extend the current chunk to cover it
                current_chunk += Span(child.start_byte, child.end_byte)
        chunks.append(current_chunk)
        return chunks
    chunks = chunk_node(tree.root_node)

    # 2. Filling in the gaps: extend each chunk to the start of the next one,
    # and the last chunk to the end of the file
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    chunks[-1].end = tree.root_node.end_byte

    # 3. Combining small chunks with bigger ones
    source_str = source_code.decode("utf-8")
    new_chunks = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        current_text = current_chunk.extract(source_str)
        if non_whitespace_len(current_text) > coalesce and "\n" in current_text:
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)

    # 4. Changing line numbers
    line_chunks = [
        Span(
            get_line_number(chunk.start, source_code),
            get_line_number(chunk.end, source_code)
        )
        for chunk in new_chunks
    ]

    # 5. Eliminating empty chunks
    line_chunks = [chunk for chunk in line_chunks if len(chunk) > 0]

    return line_chunks
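
# Sketch of standalone usage (illustrative only; `source` is a hypothetical code string,
# and this mirrors what chunk_code below does):
#   parser = get_parser("python")
#   tree = parser.parse(source.encode("utf-8"))
#   line_spans = chunk_tree(tree, source.encode("utf-8"), MAX_CHARS=1500, coalesce=100)
#   chunks = [span.extract_lines(source) for span in line_spans]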

css = """
.code_container {
}
"""

def chunk_code(
    code: str,
    language: str,
    MAX_CHARS: int,
    coalesce: int
):
    try:
        parser = get_parser(language)
        tree = parser.parse(code.encode("utf-8"))
        chunks = chunk_tree(tree, code.encode("utf-8"), MAX_CHARS=MAX_CHARS, coalesce=coalesce)
        chunks = [chunk.extract_lines(code) for chunk in chunks]
        return "\n\n====================\n\n".join(chunks)
    except Exception as e:
        return str(e)

with gr.Blocks(css=css) as demo:
    gr.Markdown("Start typing below and the chunked output will automatically show up.")

    default_file = "https://raw.githubusercontent.com/sweepai/sweep/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py"
    default_code = requests.get(default_file).text

    with gr.Row():
        language = gr.Dropdown(["python", "javascript", "go", "ruby", "java", "php", "c", "cpp", "rust", "haskell"], label="Language", value="python")
        max_chars = gr.Slider(100, 3000, 1500, label="Max Characters")
        coalesce = gr.Slider(0, 300, 100, label="Coalesce")
    with gr.Row():
        inp = gr.Code(placeholder="Enter the code here", label="Code to Chunk", language=language.value, lines=60, elem_classes="code_container", value=default_code)
        out = gr.Code(label="Chunked Code", language=language.value, lines=60, value=chunk_code(default_code, language.value, max_chars.value, coalesce.value))

    def update_language(inp, language, max_chars, coalesce):
        # Event handlers receive component *values*, so `inp` is already a string
        return (
            gr.update(language=language),
            gr.update(language=language, value=chunk_code(inp, language, max_chars, coalesce))
        )

    language.change(fn=update_language, inputs=[inp, language, max_chars, coalesce], outputs=[inp, out])
    max_chars.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)
    coalesce.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)
    inp.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)

demo.launch()