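"""Gradio demo that chunks source code into spans with tree-sitter.

Follows Sweep's chunking approach described at
https://docs.sweep.dev/blogs/chunking-2m-files and
https://docs.sweep.dev/blogs/chunking-improvements.
"""
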
from __future__ import annotations
import re
import requests
from dataclasses import dataclass
import gradio as gr
from tree_sitter import Tree, Node
from tree_sitter_languages import get_parser

def non_whitespace_len(s: str) -> int:  # length of s ignoring whitespace
    return len(re.sub(r"\s", "", s))

def get_line_number(index: int, source_code: str) -> int:
    # Map a character index into source_code to a 0-indexed line number
    total_chars = 0
    line_number = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        total_chars += len(line)
        if total_chars > index:
            return line_number - 1
    return line_number
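
# Illustrative check (comments only, not executed by the app): with
# source_code = "a\nbb\n", get_line_number(0, source_code) evaluates to 0 and
# get_line_number(3, source_code) to 1, since byte 3 falls on the second
# (0-indexed) line.
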
@dataclass
class Span:
    # Represents a slice of a string
    start: int = 0
    end: int = 0

    def __post_init__(self):
        # If end is None, set it to start
        if self.end is None:
            self.end = self.start

    def extract(self, s: str) -> str:
        # Grab the corresponding substring of string s by bytes
        return s[self.start:self.end]

    def extract_lines(self, s: str) -> str:
        # Grab the corresponding substring of string s by lines
        return "\n".join(s.splitlines()[self.start:self.end])

    def __add__(self, other: Span | int) -> Span:
        # e.g. Span(1, 2) + Span(2, 4) = Span(1, 4) (concatenation)
        # There are no safety checks: Span(a, b) + Span(c, d) = Span(a, d),
        # and there is no requirement that b == c.
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        elif isinstance(other, Span):
            return Span(self.start, other.end)
        else:
            raise NotImplementedError()

    def __len__(self) -> int:
        # i.e. len(Span(a, b)) == b - a
        return self.end - self.start
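
# Quick illustration of Span arithmetic (comments only, not executed by the app):
#   Span(1, 2) + Span(2, 4) == Span(1, 4)   # Span + Span keeps the outer bounds
#   Span(0, 5) + 3 == Span(3, 8)            # Span + int shifts both endpoints
#   len(Span(3, 10)) == 7
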
def chunk_tree(
    tree: Tree,
    source_code: bytes,
    MAX_CHARS=512 * 3,
    coalesce=50  # Any chunk less than 50 characters long gets coalesced with the next chunk
) -> list[Span]:
    # 1. Recursively form chunks based on the last post (https://docs.sweep.dev/blogs/chunking-2m-files)
    def chunk_node(node: Node) -> list[Span]:
        chunks: list[Span] = []
        current_chunk: Span = Span(node.start_byte, node.start_byte)
        node_children = node.children
        for child in node_children:
            if child.end_byte - child.start_byte > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.end_byte, child.end_byte)
                chunks.extend(chunk_node(child))
            elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.start_byte, child.end_byte)
            else:
                current_chunk += Span(child.start_byte, child.end_byte)
        chunks.append(current_chunk)
        return chunks

    chunks = chunk_node(tree.root_node)
    # 2. Filling in the gaps between adjacent chunks and extending the last chunk
    # to the end of the parsed file
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    chunks[-1].end = tree.root_node.end_byte
    # 3. Combining small chunks with bigger ones
    new_chunks = []
    current_chunk = Span(0, 0)
    source_str = source_code.decode("utf-8")  # decode once instead of once per chunk
    for chunk in chunks:
        current_chunk += chunk
        if non_whitespace_len(current_chunk.extract(source_str)) > coalesce \
                and "\n" in current_chunk.extract(source_str):
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)

    # 4. Changing line numbers
    line_chunks = [
        Span(
            get_line_number(chunk.start, source_code),
            get_line_number(chunk.end, source_code)
        )
        for chunk in new_chunks
    ]

    # 5. Eliminating empty chunks
    line_chunks = [chunk for chunk in line_chunks if len(chunk) > 0]
    return line_chunks
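
# Hedged usage sketch for calling chunk_tree outside the UI (kept as comments so
# the app's behavior is unchanged). "example.py" is a placeholder path; the
# grammar name must be one that tree_sitter_languages provides, as chunk_code
# below assumes.
#
#   parser = get_parser("python")
#   source_bytes = open("example.py", "rb").read()
#   tree = parser.parse(source_bytes)
#   for span in chunk_tree(tree, source_bytes, MAX_CHARS=1500, coalesce=100):
#       print(span.extract_lines(source_bytes.decode("utf-8")))
#
# The returned Spans are line ranges after step 4, so extract_lines (not extract)
# is the right accessor here.
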
css = """
.code_container {
}
"""
def chunk_code(
    code: str,
    language: str,
    MAX_CHARS: int,
    coalesce: int
):
    try:
        parser = get_parser(language)
        tree = parser.parse(code.encode("utf-8"))
        chunks = chunk_tree(tree, code.encode("utf-8"), MAX_CHARS=MAX_CHARS, coalesce=coalesce)
        chunks = [chunk.extract_lines(code) for chunk in chunks]
        return "\n\n====================\n\n".join(chunks)
    except Exception as e:
        return str(e)
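
# chunk_code wraps the pipeline above for the UI: on success it returns the chunks
# joined by "====================" separators, and on failure (e.g. a grammar that
# tree_sitter_languages cannot load) it returns the exception text so the error
# shows up in the output pane.
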
with gr.Blocks(css=css) as demo:
    gr.Markdown("## Code Chunking Demo")
    gr.Markdown("Start typing below and the chunked output will automatically show up. Check out how this algorithm works at https://docs.sweep.dev/blogs/chunking-2m-files and https://docs.sweep.dev/blogs/chunking-improvements. We also have interactive notebooks at https://github.com/sweepai/sweep/blob/main/notebooks/chunking.ipynb.")
    default_file = "https://raw.githubusercontent.com/sweepai/sweep/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py"
    default_code = requests.get(default_file).text
    with gr.Row():
        language = gr.Dropdown(["python", "javascript", "go", "ruby", "java", "php", "c", "cpp", "rust", "haskell"], label="Language", value="python")
        max_chars = gr.Slider(1, 3000, 1500, label="Max Characters", step=10)
        coalesce = gr.Slider(0, 300, 100, label="Coalesce", step=10)
    with gr.Row():
        inp = gr.Code(placeholder="Enter the code here", label="Code to Chunk", language=language.value, lines=60, elem_classes="code_container", value=default_code)
        out = gr.Code(label="Chunked Code", language=language.value, lines=60, value=chunk_code(default_code, language.value, max_chars.value, coalesce.value))

    def update_language(inp, language, max_chars, coalesce):
        # Event handlers receive component values, so inp is already the code string
        return (
            gr.update(language=language),
            gr.update(language=language, value=chunk_code(inp, language, max_chars, coalesce))
        )

    language.change(fn=update_language, inputs=[inp, language, max_chars, coalesce], outputs=[inp, out])
    max_chars.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)
    coalesce.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)
    inp.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)

demo.launch()