# CodeLATS / generators/generator_utils.py
# Last commit: 41d1bc5 "initial commit" (by Ron)
from generators.model import ModelBase, Message
import random
import streamlit as st
from typing import Union, List, Optional, Callable
def generic_generate_func_impl(
    func_sig: str,
    model: ModelBase,
    strategy: str,
    prev_func_impl,
    feedback,
    self_reflection,
    num_comps,
    temperature,
    reflexion_chat_instruction: str,
    reflexion_few_shot: str,
    simple_chat_instruction: str,
    reflexion_completion_instruction: str,
    simple_completion_instruction: str,
    code_block_instruction: str,
    parse_code_block: Callable[[str], str],
    add_code_block: Callable[[str], str]
) -> Union[str, List[str]]:
    """Ask ``model`` to implement ``func_sig`` and return the parsed code.

    Two prompting strategies are supported:

    * ``"simple"``: the model only sees the instructions and the signature.
    * ``"reflexion"``: the model additionally sees the previous implementation,
      its unit-test feedback and a self-reflection on it.

    Chat models (``model.is_chat``) get a multi-turn message list; other models
    get a single completion prompt.

    Args:
        func_sig: Signature of the function to implement.
        model: Backend used for generation.
        strategy: ``"reflexion"`` or ``"simple"``.
        prev_func_impl: Previous implementation (required for ``"reflexion"``).
        feedback: Unit-test results for the previous implementation.
        self_reflection: Reflection on the previous implementation.
        num_comps: Number of completions to request.
        temperature: Sampling temperature forwarded to the model.
        reflexion_chat_instruction, reflexion_few_shot, simple_chat_instruction,
        reflexion_completion_instruction, simple_completion_instruction,
        code_block_instruction: Prompt fragments.
        parse_code_block: Extracts code from a raw model response.
        add_code_block: Wraps an implementation for inclusion in a prompt.

    Returns:
        One parsed implementation when ``num_comps == 1``, otherwise a list of them.

    Raises:
        ValueError: On an unknown ``strategy`` or missing reflexion inputs.
    """
    if strategy not in ("reflexion", "simple"):
        raise ValueError(
            f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")
    use_reflexion = strategy == "reflexion"
    if use_reflexion and any(arg is None for arg in (prev_func_impl, feedback, self_reflection)):
        raise ValueError(
            f"Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, `feedback`, or `self_reflection` is None")

    if model.is_chat:
        if use_reflexion:
            # Flattened, human-readable version of the conversation, used only for logging.
            transcript = (
                f"{reflexion_few_shot}\n"
                f"[previous impl]:\n{add_code_block(prev_func_impl)}\n\n"
                f"[unit test results from previous impl]:\n{feedback}\n\n"
                f"[reflection on previous impl]:\n{self_reflection}\n\n"
                f"[improved impl]:\n{func_sig}"
            )
            system_text = f"{reflexion_chat_instruction}\n{code_block_instruction}"
            print_messages(system_text, transcript)
            # The earlier attempt and its reflection are replayed as assistant turns.
            conversation = [
                Message(role="system", content=system_text),
                Message(role="user", content=reflexion_few_shot),
                Message(role="assistant", content=add_code_block(prev_func_impl)),
                Message(
                    role="user",
                    content=f"[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:",
                ),
                Message(role="assistant", content=self_reflection),
                Message(role="user", content=f"[improved impl]:\n{func_sig}"),
            ]
        else:
            system_text = f"{simple_chat_instruction}\n{code_block_instruction}"
            print_messages(system_text, func_sig)
            conversation = [
                Message(role="system", content=system_text),
                Message(role="user", content=func_sig),
            ]
        raw_output = model.generate_chat(messages=conversation, num_comps=num_comps, temperature=temperature)
    else:
        if use_reflexion:
            completion_prompt = (
                f"{reflexion_completion_instruction}\n{add_code_block(prev_func_impl)}\n\n"
                f"unit tests:\n{feedback}\n\n"
                f"hint:\n{self_reflection}\n\n"
                f"# improved implementation\n{func_sig}\n{code_block_instruction}"
            )
        else:
            completion_prompt = f"{simple_completion_instruction}\n{func_sig}\n{code_block_instruction}"
        raw_output = model.generate(completion_prompt, num_comps=num_comps, temperature=temperature)

    if num_comps != 1:
        parsed_bodies = [parse_code_block(candidate) for candidate in raw_output]
        print_generated_func_body("\n\n".join(parsed_bodies))
        return parsed_bodies

    assert isinstance(raw_output, str)
    parsed_body = parse_code_block(raw_output)
    print_generated_func_body(parsed_body)
    return parsed_body
def generate_with_accumulated_context(
    func_sig: str,
    model: ModelBase,
    strategy: str,
    prev_func_impl,
    accumulated_feedback,
    accumulated_reflection,
    num_comps,
    temperature,
    reflexion_chat_instruction: str,
    reflexion_few_shot: str,
    simple_chat_instruction: str,
    reflexion_completion_instruction: str,
    simple_completion_instruction: str,
    code_block_instruction: str,
    parse_code_block: Callable[[str], str],
    add_code_block: Callable[[str], str]
) -> Union[str, List[str]]:
    """Generate implementation(s) of ``func_sig`` conditioned on all prior attempts.

    This differs from ``generic_generate_func_impl`` only for the ``"reflexion"``
    strategy. There, ``prev_func_impl``, ``accumulated_feedback`` and
    ``accumulated_reflection`` are iterables consumed in lockstep with ``zip``,
    one entry per earlier attempt, so the prompt contains the whole history.
    ``zip`` stops at the shortest of the three, so extra entries are silently
    ignored. NOTE(review): they are iterated more than once, so pass sequences,
    not one-shot iterators.

    Args:
        func_sig: Signature of the function to implement.
        model: Backend; ``model.is_chat`` selects chat vs. completion prompting.
        strategy: ``"reflexion"`` (use history) or ``"simple"`` (signature only).
        prev_func_impl: Earlier implementations (required for ``"reflexion"``).
        accumulated_feedback: Unit-test feedback, one per earlier implementation.
        accumulated_reflection: Self-reflection, one per earlier implementation.
        num_comps: Number of completions to request.
        temperature: Sampling temperature forwarded to the model.
        reflexion_chat_instruction, reflexion_few_shot, simple_chat_instruction,
        reflexion_completion_instruction, simple_completion_instruction,
        code_block_instruction: Prompt fragments.
        parse_code_block: Extracts code from a raw model response.
        add_code_block: Wraps an implementation for inclusion in a prompt.

    Returns:
        One parsed implementation when ``num_comps == 1``, otherwise a list of them.

    Raises:
        ValueError: If ``strategy`` is unknown, or if ``"reflexion"`` is requested
            while any history argument is ``None``.
    """
    if strategy != "reflexion" and strategy != "simple":
        raise ValueError(
            f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")
    use_history = strategy == "reflexion"
    if use_history and (prev_func_impl is None or accumulated_feedback is None or accumulated_reflection is None):
        raise ValueError(
            "Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, "
            "`accumulated_feedback`, or `accumulated_reflection` is None")

    # The history is only needed (and only validated as non-None) for "reflexion".
    # Building it unconditionally made strategy="simple" with None history crash in zip().
    accumulated_context = ""
    if use_history:
        accumulated_context = "\n\n".join(
            f"[previous impl {i+1}]:\n{add_code_block(impl)}\n"
            f"[unit test results from previous impl {i+1}]:\n{feedback}\n"
            f"[reflection on previous impl {i+1}]:\n{reflection}"
            for i, (impl, feedback, reflection) in enumerate(
                zip(prev_func_impl, accumulated_feedback, accumulated_reflection))
        )

    if model.is_chat:
        if use_history:
            messages = [
                Message(role="system", content=f"{reflexion_chat_instruction}\n{code_block_instruction}"),
                Message(role="user", content=reflexion_few_shot),
            ]
            # Replay each earlier attempt as an assistant turn followed by its feedback.
            for impl, feedback, reflection in zip(prev_func_impl, accumulated_feedback, accumulated_reflection):
                messages.append(Message(role="assistant", content=add_code_block(impl)))
                messages.append(Message(
                    role="user",
                    content=f"[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:\n{reflection}",
                ))
            # NOTE(review): with at least one earlier attempt this yields two consecutive
            # "user" messages; confirm the chat backend accepts that.
            messages.append(Message(role="user", content=f"[improved impl]:\n{func_sig}"))
            # Both strings below are for logging only; they are not sent to the model.
            full_prompt = "\n".join(m.content for m in messages)
            history_summary = f"{reflexion_few_shot}\n{accumulated_context}\n\n[improved impl]:\n{func_sig}"
            print_messages(full_prompt, history_summary)
            func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
        else:
            system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}"
            print_messages(system_prompt, func_sig)
            messages = [
                Message(role="system", content=system_prompt),
                Message(role="user", content=func_sig),
            ]
            func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
    else:
        if use_history:
            prompt = f"{reflexion_completion_instruction}\n{accumulated_context}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}"
        else:
            prompt = f"{simple_completion_instruction}\n{func_sig}\n{code_block_instruction}"
        func_bodies = model.generate(prompt, num_comps=num_comps, temperature=temperature)
        print_messages(prompt, "")

    if num_comps == 1:
        assert isinstance(func_bodies, str)
        func_body_str = parse_code_block(func_bodies)
        print_generated_func_body(func_body_str)
        return func_body_str
    func_bodies = [parse_code_block(func_body) for func_body in func_bodies]
    print_generated_func_body("\n\n".join(func_bodies))
    return func_bodies
def generic_generate_internal_tests(
    func_sig: str,
    model: ModelBase,
    max_num_tests: int,
    test_generation_few_shot: str,
    test_generation_chat_instruction: str,
    test_generation_completion_instruction: str,
    parse_tests: Callable[[str], List[str]],
    is_syntax_valid: Callable[[str], bool],
    is_react: bool = False
) -> List[str]:
    """Generate unit tests for ``func_sig`` with ``model``.

    Chat models get a system instruction plus one user turn with the few-shot
    examples and the signature. With ``is_react`` the turn ends in ``[think]:``
    instead of ``[unit tests]:`` (presumably to elicit reasoning before the
    tests; verify against the prompt templates). Completion models get a
    single prompt. Generation is capped at 1024 tokens.

    Args:
        func_sig: Signature of the function under test.
        model: Backend used for generation.
        max_num_tests: Upper bound on the number of tests returned.
        test_generation_few_shot: Few-shot examples (chat models only).
        test_generation_chat_instruction: System prompt for chat models.
        test_generation_completion_instruction: Prompt prefix for completion models.
        parse_tests: Splits the raw model output into individual tests.
        is_syntax_valid: Filter that keeps only syntactically valid tests.
        is_react: Use the ``[think]:`` prompt variant (chat models only).

    Returns:
        At most ``max_num_tests`` syntactically valid tests, in the order the
        model produced them.
    """
    if model.is_chat:
        if is_react:
            messages = [
                Message(role="system", content=test_generation_chat_instruction),
                Message(
                    role="user",
                    content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[think]:",
                ),
            ]
            output = model.generate_chat(messages=messages, max_tokens=1024)
            print(f'React test generation output: {output}')
        else:
            messages = [
                Message(role="system", content=test_generation_chat_instruction),
                Message(
                    role="user",
                    content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[unit tests]:",
                ),
            ]
            output = model.generate_chat(messages=messages, max_tokens=1024)
    else:
        prompt = f'{test_generation_completion_instruction}\n\nfunc signature:\n{func_sig}\nunit tests:'
        output = model.generate(prompt, max_tokens=1024)
    all_tests = parse_tests(output)  # type: ignore
    valid_tests = [test for test in all_tests if is_syntax_valid(test)]
    # Honour the caller's cap deterministically (keeps generation order);
    # clamp so a negative value yields no tests rather than dropping from the end.
    return valid_tests[:max(max_num_tests, 0)]
def generic_generate_self_reflection(
    func: str,
    feedback: str,
    model: ModelBase,
    self_reflection_chat_instruction: str,
    self_reflection_completion_instruction: str,
    add_code_block: Callable[[str], str],
    self_reflection_few_shot: Optional[str] = None,
) -> str:
    """Ask ``model`` to reflect on an implementation and its test results.

    Completion models receive one prompt: the instruction, the wrapped
    implementation, the feedback, and a trailing ``Explanation:`` cue. Chat
    models receive a system instruction plus a single user turn. That turn is
    optionally prefixed with ``self_reflection_few_shot``, and the output is
    printed only when few-shot examples were supplied. The raw model output
    is returned unchanged.
    """
    if not model.is_chat:
        completion_prompt = f'{self_reflection_completion_instruction}\n{add_code_block(func)}\n\n{feedback}\n\nExplanation:'
        return model.generate(completion_prompt)  # type: ignore

    has_examples = self_reflection_few_shot is not None
    preamble = f'{self_reflection_few_shot}\n\n' if has_examples else ''
    conversation = [
        Message(role="system", content=self_reflection_chat_instruction),
        Message(
            role="user",
            content=f'{preamble}[function impl]:\n{add_code_block(func)}\n\n[unit test results]:\n{feedback}\n\n[self-reflection]:',
        ),
    ]
    reflection = model.generate_chat(messages=conversation)
    if has_examples:
        print(f'|Self reflection output|: {reflection}')
    return reflection  # type: ignore
def sample_n_random(items: List[str], n: int) -> List[str]:
    """Return up to ``n`` randomly chosen elements of ``items``.

    When ``n`` is at least ``len(items)``, the very same ``items`` list is
    returned, not a copy. Otherwise a new list of ``n`` elements is drawn
    without replacement via ``random.sample``. ``n`` must be non-negative.
    """
    assert n >= 0
    return items if n >= len(items) else random.sample(items, n)
def print_messages(system_message_text: str, user_message_text: str) -> None:
    """Echo a system/user prompt pair to stdout for debugging.

    The system text is printed on its own line. The user text follows with a
    trailing space and an extra blank line.
    """
    print(f"{system_message_text}\n{user_message_text} \n")
def print_generated_func_body(func_body_str: str) -> None:
    """Print a generated function body inside a fenced ``python`` markdown block.

    The original output opened a "```python" fence but never closed it, which
    leaves the markdown malformed; the closing fence is now emitted.
    """
    print(f"|GENERATED FUNCTION BODY| \n\n```python\n{func_body_str}\n```\n")