# chunker / app.py
# Author: kevinlu1248 — last commit a65612d ("Final brush up"), 7.48 kB
from __future__ import annotations
import re
import requests
from dataclasses import dataclass
import gradio as gr
from tree_sitter import Tree, Node
from tree_sitter_languages import get_parser
def non_whitespace_len(s: str) -> int:
    """Return the number of non-whitespace characters in ``s``.

    Used instead of ``len`` when deciding whether a chunk is big enough,
    so indentation and blank lines do not count toward its size.
    """
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    return len(re.sub(r"\s", "", s))
def get_line_number(index: int, source_code: "str | bytes") -> int:
    """Map an offset into ``source_code`` to a 0-based line index.

    Offsets are counted in the units of ``source_code``: characters for
    ``str``, bytes for ``bytes``. Line breaks follow ``splitlines``.

    Args:
        index: offset into ``source_code``.
        source_code: the text (or encoded bytes) the offset refers to.

    Returns:
        The 0-based index of the line containing ``index``. An offset at or
        past the end returns the number of lines, which works as an
        exclusive slice end. Empty input returns 0.
    """
    consumed = 0
    # Pre-bind so empty input returns 0 instead of raising UnboundLocalError.
    line_number = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        consumed += len(line)
        if consumed > index:
            return line_number - 1
    return line_number
@dataclass
class Span:
    """A half-open ``[start, end)`` slice of a sequence.

    In this file spans hold either byte offsets into the source (while
    chunking) or line indices (after conversion in ``chunk_tree``).
    """

    start: int = 0
    # Omitting ``end`` (or passing None) yields an empty span at ``start``.
    end: "int | None" = None

    def __post_init__(self) -> None:
        # Normalize a missing end to an empty span anchored at start.
        if self.end is None:
            self.end = self.start

    def extract(self, s: str) -> str:
        """Return ``s[start:end]``; works for any sliceable (``str`` or ``bytes``)."""
        return s[self.start : self.end]

    def extract_lines(self, s: str) -> str:
        """Return lines ``start`` up to (not including) ``end`` of ``s``, joined by newlines."""
        return "\n".join(s.splitlines()[self.start : self.end])

    def __add__(self, other: "Span | int") -> "Span":
        """Concatenate with another span, or shift both ends by an int.

        ``Span(1, 2) + Span(2, 4) == Span(1, 4)``. No contiguity check is
        made: ``Span(a, b) + Span(c, d)`` is always ``Span(a, d)``.
        ``Span(a, b) + k == Span(a + k, b + k)``.
        """
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        elif isinstance(other, Span):
            return Span(self.start, other.end)
        else:
            raise NotImplementedError()

    def __len__(self) -> int:
        """Return ``end - start``."""
        return self.end - self.start
def chunk_tree(
    tree: Tree,
    source_code: bytes,
    MAX_CHARS=512 * 3,
    coalesce=50,  # Chunks with at most this many non-whitespace chars are merged forward
) -> list[Span]:
    """Split a parsed source file into line-range chunks along syntax-node boundaries.

    Algorithm (see https://docs.sweep.dev/blogs/chunking-2m-files):

    1. Recursively walk the tree, greedily packing sibling nodes into byte
       spans of at most ``MAX_CHARS``. Any single node larger than that is
       split recursively.
    2. Stretch each span to the start of the next so no bytes are lost, and
       extend the last span to the end of the root node.
    3. Merge small spans forward until the accumulated text has more than
       ``coalesce`` non-whitespace characters and contains a newline.
    4. Convert byte offsets to 0-based line indices.
    5. Drop chunks that cover zero lines.

    Args:
        tree: tree-sitter parse tree of ``source_code``.
        source_code: the exact UTF-8 bytes that were parsed. Node offsets
            are byte offsets into this buffer.
        MAX_CHARS: upper bound, in bytes, on a chunk built in step 1.
        coalesce: a chunk is emitted in step 3 only once it has more than
            this many non-whitespace characters.

    Returns:
        Half-open ``Span``s of line indices into ``source_code``.
    """

    def chunk_node(node: Node) -> list[Span]:
        # Greedily pack the children of ``node`` into spans of <= MAX_CHARS bytes.
        chunks: list[Span] = []
        current_chunk = Span(node.start_byte, node.start_byte)
        for child in node.children:
            child_len = child.end_byte - child.start_byte
            if child_len > MAX_CHARS:
                # Child cannot fit in any chunk: flush, then split the child itself.
                chunks.append(current_chunk)
                current_chunk = Span(child.end_byte, child.end_byte)
                chunks.extend(chunk_node(child))
            elif child_len + len(current_chunk) > MAX_CHARS:
                # Adding the child would overflow: start a new chunk with it.
                chunks.append(current_chunk)
                current_chunk = Span(child.start_byte, child.end_byte)
            else:
                current_chunk += Span(child.start_byte, child.end_byte)
        chunks.append(current_chunk)
        return chunks

    # 1. Recursively form chunks from syntax nodes.
    chunks = chunk_node(tree.root_node)

    # 2. Fill the gaps so consecutive chunks are contiguous, then make the last
    #    chunk run to the end of the root node. chunk_node always returns at
    #    least one span, so chunks[-1] exists even for tiny inputs.
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    chunks[-1].end = tree.root_node.end_byte

    # 3. Combine small chunks with the ones that follow. Spans hold byte
    #    offsets, so slice the bytes first and decode only that slice.
    #    Slicing the decoded str would misalign on non-ASCII source and would
    #    also re-decode the whole file on every iteration.
    new_chunks: list[Span] = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        text = current_chunk.extract(source_code).decode("utf-8", errors="replace")
        if non_whitespace_len(text) > coalesce and "\n" in text:
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)

    # 4. Convert byte offsets into 0-based line indices (end stays exclusive).
    line_chunks = [
        Span(
            get_line_number(chunk.start, source_code),
            get_line_number(chunk.end, source_code),
        )
        for chunk in new_chunks
    ]

    # 5. Eliminate chunks that cover no lines.
    return [chunk for chunk in line_chunks if len(chunk) > 0]
# Extra stylesheet passed to gr.Blocks. The rule body is empty; it is only a
# hook for the input editor, which carries elem_classes="code_container".
css = "\n.code_container {\n}\n"
def chunk_code(
    code: str,
    language: str,
    MAX_CHARS: int,
    coalesce: int
):
    """Chunk ``code`` and render the chunks as a single display string.

    Args:
        code: source text to chunk.
        language: language name passed to ``get_parser``.
        MAX_CHARS: maximum chunk size, forwarded to ``chunk_tree``.
        coalesce: small-chunk merge threshold, forwarded to ``chunk_tree``.

    Returns:
        The chunks joined by a ``====`` separator. If parsing or chunking
        fails, the exception message is returned instead, so the UI shows
        the error rather than crashing.
    """
    try:
        source_bytes = code.encode("utf-8")
        tree = get_parser(language).parse(source_bytes)
        line_spans = chunk_tree(tree, source_bytes, MAX_CHARS=MAX_CHARS, coalesce=coalesce)
        # chunk_tree numbers lines over these bytes (get_line_number on
        # bytes.splitlines), which only breaks on \n, \r and \r\n.
        # str.splitlines() also breaks on \x0c, \x1c, \u2028, etc., which would
        # shift every chunk after such a character. So split the same bytes.
        # Decoding per line is safe: line breaks are ASCII and never occur
        # inside a UTF-8 multi-byte sequence.
        source_lines = [line.decode("utf-8") for line in source_bytes.splitlines()]
        chunks = ["\n".join(source_lines[span.start:span.end]) for span in line_spans]
        return "\n\n====================\n\n".join(chunks)
    except Exception as e:
        return str(e)
# Examples offered in the "Examples" dropdown: label -> (raw file URL, language).
examples_dict = {
    "Python: Sweep's GitHub Actions log handler": ("https://raw.githubusercontent.com/sweepai/sweep/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py", "python"),
    "Typescript: LlamaIndex TS's BaseIndex abstract base class": ("https://raw.githubusercontent.com/run-llama/LlamaIndexTS/bfab1d407b7b390d76b3d7a1a1df0928e9f9ae11/packages/core/src/indices/BaseIndex.ts", "typescript"),
    "Rust: Ruff's autofix code modification algorithm": ("https://raw.githubusercontent.com/astral-sh/ruff/main/crates/ruff/src/autofix/codemods.rs", "rust"),
    "Go: Infisical's CLI's config manager": ("https://raw.githubusercontent.com/Infisical/infisical/de7bd27b4b48847c9ca7cd12d208225b06f170fe/cli/packages/util/config.go", "go")
}


def _fetch_code(url: str) -> str:
    """Download the text body at ``url`` (used to load example source files)."""
    # Without a timeout, requests can block forever on a stalled connection.
    return requests.get(url, timeout=30).text


# Must be an exact key of examples_dict; it is looked up immediately below.
default_key = "Python: Sweep's GitHub Actions log handler"
default_url, default_language = examples_dict[default_key]
default_code = _fetch_code(default_url)

with gr.Blocks(css=css) as demo:
    gr.Markdown("## Code Chunking Demo")
    gr.Markdown("Start typing below and the chunked output will automatically show up. Checkout how this algorithm works at https://docs.sweep.dev/blogs/chunking-2m-files and https://docs.sweep.dev/blogs/chunking-improvements or play with the notebook at https://github.com/sweepai/sweep/blob/main/notebooks/chunking.ipynb.")
    with gr.Row():
        language = gr.Dropdown(['python', 'javascript', 'typescript', 'rust', 'go', 'ruby', 'r', 'html', 'css', 'shell'], label="Language", value=default_language)
        max_chars = gr.Slider(10, 3000, 1500, label="Max Characters", step=10)
        coalesce = gr.Slider(0, 300, 100, label="Coalesce", step=10)
        examples = gr.Dropdown(list(examples_dict.keys()), label="Examples", value=default_key, interactive=True)
    with gr.Row():
        input_code = gr.Code(label="Input Code", language=language.value, lines=60, elem_classes="code_container", value=default_code)
        output_code = gr.Code(label="Chunked Code", language=language.value, lines=60, value=chunk_code(default_code, language.value, max_chars.value, coalesce.value))

    def update_examples(example_key, max_chars_value, coalesce_value):
        """Load the selected example and re-chunk it with the *current* slider values.

        The slider values arrive as event inputs. Reading ``max_chars.value``
        here would only give the sliders' initial values.
        """
        url, example_language = examples_dict[example_key]
        code = _fetch_code(url)
        return (
            gr.Code.update(language=example_language, value=code),
            gr.Code.update(language=example_language, value=chunk_code(code, example_language, max_chars_value, coalesce_value)),
            example_language,
        )

    examples.change(fn=update_examples, inputs=[examples, max_chars, coalesce], outputs=[input_code, output_code, language])

    def update_language(selected_language):
        """Switch syntax highlighting of both editors to ``selected_language``."""
        return gr.Code.update(language=selected_language), gr.Code.update(language=selected_language)

    # Inputs shared by every handler that re-chunks the input code.
    chunk_inputs = [input_code, language, max_chars, coalesce]
    language.change(fn=update_language, inputs=language, outputs=[input_code, output_code]) \
        .then(fn=chunk_code, inputs=chunk_inputs, outputs=output_code)
    max_chars.change(fn=chunk_code, inputs=chunk_inputs, outputs=output_code)
    coalesce.change(fn=chunk_code, inputs=chunk_inputs, outputs=output_code)
    input_code.change(fn=chunk_code, inputs=chunk_inputs, outputs=output_code)

demo.launch()