import json
import asyncio
from typing import Any, List
from typing_extensions import TypedDict

from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langgraph.graph import StateGraph, END

from src.utils.api_key_manager import with_api_manager
from src.helpers.helper import remove_markdown


# Define the Graph State
class GraphState(TypedDict):
    """State dictionary passed between the workflow nodes."""

    initial_prompt: str  # user request that drives the whole run
    plan: str  # plan text from planning_node; split into steps on line breaks
    write_steps: List[dict]  # one parsed JSON value per non-blank plan line
    final_json: str  # serialized {"plan": ..., "steps": ...} document


# Every prompt is sent through a template whose only variable is ``{text}``.
# The fully rendered prompt is supplied as that variable's *value*, so braces
# inside the user's request are never interpreted as template fields.
_CHAT_TEMPLATE = ChatPromptTemplate.from_messages([
    HumanMessagePromptTemplate.from_template("{text}"),
])

# Planning prompt. Earlier wording both asked for "one step per line" and
# forbade new lines; the writers split the plan on line breaks, so the plan
# must be line-oriented.
_PLAN_PROMPT_TEMPLATE = """You need to create a structured JSON based on the following instructions:
{initial_prompt}

Rules:
1. Outline a multi-step plan that will guide the creation of the final JSON.
2. You must create the entire plan yourself without asking others to create it for you.
3. The steps should be as follows:
    - Each step should be a high-level task or section of the JSON.
    - Check if breaking down each step into smaller, low-level sub-tasks or sections is required
    - If yes, ONLY include the sub-steps (one sub-step per line).
4. The plan should be concise and clear, and each step and sub-step should be distinct.
5. Write exactly one step per line, in plain text. DO NOT use bullet points or blank lines.
6. The number of steps should be as few as possible, but still enough to cover ALL sections.
7. If the user request contains any specific details, include them in the plan.
8. DO NOT create the final content, just the plan/outline.
9. DO NOT include any markdown or formatting in the plan."""

# Per-step prompt shared by the sync and async writers.
_STEP_PROMPT_TEMPLATE = """You are creating part {part} of the final JSON document.

User request:
{initial_prompt}

Plan step (outline):
{step}

Rules:
1. You need to write the JSON data for this step.
2. The JSON should be structured and valid.
3. If the user request contains any specific details, include them in the JSON.
4. If the user request contains the format of the JSON, follow it. If not, create a generic JSON as you see fit.
5. Respond ONLY with valid JSON for this step without any markdown or formatting."""


def _as_chat_prompt(text: str):
    """Wrap an already-rendered prompt string into a single human chat message."""
    return _CHAT_TEMPLATE.invoke({"text": text})


def _plan_steps(plan: str) -> List[str]:
    """Return the stripped, non-blank lines of ``plan``, one entry per plan step."""
    return [line.strip() for line in plan.splitlines() if line.strip()]


def _build_step_prompt(initial_prompt: str, part: int, step: str) -> str:
    """Render the writer prompt for plan step ``step`` (1-based ``part``)."""
    # str.format does not re-parse substituted values, so braces in the
    # user's request or the plan text are inserted verbatim.
    return _STEP_PROMPT_TEMPLATE.format(part=part, initial_prompt=initial_prompt, step=step)


def _parse_step_json(raw: str, part: int) -> Any:
    """Strip markdown from a model response and parse it as JSON.

    Raises:
        Exception: if the cleaned text is not valid JSON; the underlying
            ``json.JSONDecodeError`` is chained as ``__cause__``.
    """
    cleaned = remove_markdown(raw.strip())
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError as e:
        raise Exception(f"Failed to parse JSON data for step {part}: {e}") from e


@with_api_manager(temperature=0.0, top_p=1.0)
def planning_node(state: GraphState, *, llm) -> GraphState:
    """Ask the LLM for a line-oriented plan and store it in ``state['plan']``.

    ``llm`` is injected by ``with_api_manager``. The response is stripped and
    passed through ``remove_markdown`` before being stored.
    NOTE(review): the writers split the plan on line breaks; confirm that
    ``remove_markdown`` preserves newlines.
    """
    print("\n---PLANNING---\n")
    plan_prompt = _PLAN_PROMPT_TEMPLATE.format(initial_prompt=state['initial_prompt'])
    response = llm.invoke(_as_chat_prompt(plan_prompt))
    state['plan'] = remove_markdown(response.content.strip())
    # Print what is actually stored and used downstream.
    print(state['plan'])
    return state


@with_api_manager(temperature=0.0, top_p=1.0)
def writing_node_sync(state: GraphState, *, llm) -> GraphState:
    """Generate one JSON fragment per plan step, one LLM call at a time.

    Steps are the non-blank lines of ``state['plan']``, numbered 1..N.
    Results are stored in order in ``state['write_steps']``.

    Raises:
        Exception: if any step's response is not valid JSON.
    """
    print("\n---WRITING THE JSON---\n")
    initial_prompt = state['initial_prompt']
    partial_jsons: List[Any] = []
    for part, step in enumerate(_plan_steps(state['plan']), start=1):
        prompt = _as_chat_prompt(_build_step_prompt(initial_prompt, part, step))
        response = llm.invoke(prompt)
        partial_jsons.append(_parse_step_json(response.content, part))
    state['write_steps'] = partial_jsons
    return state


@with_api_manager(temperature=0.0, top_p=1.0)
async def writing_node_async(state: GraphState, *, llm) -> GraphState:
    """Generate one JSON fragment per plan step, issuing all LLM calls concurrently.

    Produces the same ``state['write_steps']`` as ``writing_node_sync``, in
    plan order. If any step fails, the still-pending step tasks are
    cancelled and the error is re-raised.
    NOTE(review): assumes ``with_api_manager`` supports wrapping coroutine
    functions. Confirm in src.utils.api_key_manager.
    """
    print("\n---WRITING THE JSON---\n")
    initial_prompt = state['initial_prompt']
    steps = _plan_steps(state['plan'])

    async def get_partial_json(part: int, step: str) -> Any:
        prompt = _as_chat_prompt(_build_step_prompt(initial_prompt, part, step))
        response = await llm.ainvoke(prompt)
        return _parse_step_json(response.content, part)

    tasks = [
        asyncio.create_task(get_partial_json(part, step))
        for part, step in enumerate(steps, start=1)
    ]
    try:
        # gather preserves task order, so results line up with plan steps.
        partial_jsons = await asyncio.gather(*tasks)
    except Exception:
        # gather does not cancel siblings when one task raises; stop the
        # remaining (paid) LLM calls instead of letting them run on.
        for task in tasks:
            task.cancel()
        raise

    state['write_steps'] = list(partial_jsons)
    return state


def consolidation_node(state: GraphState) -> GraphState:
    """Combine the plan and step fragments into one pretty-printed JSON string.

    Stores ``{"plan": <plan text>, "steps": <write_steps>}`` serialized with
    ``ensure_ascii=False`` and ``indent=2`` in ``state['final_json']``.
    """
    print("\n---CONSOLIDATING THE JSON---\n")
    final_obj = {
        "plan": state['plan'],
        "steps": state['write_steps'],
    }
    state['final_json'] = json.dumps(final_obj, ensure_ascii=False, indent=2)
    return state


def create_workflow_sync() -> StateGraph:
    """Build and compile the graph planning -> sequential writing -> consolidation.

    NOTE(review): ``workflow.compile()`` returns a compiled graph, not a
    ``StateGraph``; the annotation is kept unchanged for compatibility.
    """
    workflow = StateGraph(GraphState)

    workflow.add_node("planning_node", planning_node)
    workflow.add_node("writing_node", writing_node_sync)
    workflow.add_node("consolidation_node", consolidation_node)

    workflow.set_entry_point("planning_node")

    workflow.add_edge("planning_node", "writing_node")
    workflow.add_edge("writing_node", "consolidation_node")
    workflow.add_edge("consolidation_node", END)

    return workflow.compile()
def create_workflow_async() -> StateGraph:
    """Build and compile the graph planning -> concurrent writing -> consolidation.

    Uses ``writing_node_async``, an ``async def`` node, so the compiled graph
    is meant to be run with ``ainvoke`` (as in the ``__main__`` demo below).
    NOTE(review): ``workflow.compile()`` returns a compiled graph, not a
    ``StateGraph``; the annotation is kept unchanged for compatibility.
    """
    workflow = StateGraph(GraphState)

    workflow.add_node("planning_node", planning_node)
    workflow.add_node("writing_node", writing_node_async)
    workflow.add_node("consolidation_node", consolidation_node)

    workflow.set_entry_point("planning_node")

    workflow.add_edge("planning_node", "writing_node")
    workflow.add_edge("writing_node", "consolidation_node")
    workflow.add_edge("consolidation_node", END)

    return workflow.compile()


if __name__ == "__main__":
    import time

    # Implicit concatenation with explicit trailing spaces. The previous
    # backslash continuations produced "creators.Do" with no space.
    test_instruction = (
        "Write a 1500-word piece on the HBO TV show Westworld, covering major characters, "
        "themes of AI and consciousness, and how the story might have continued "
        "had it not been cancelled. "
        "Include specific details, quotes, and references to the show and its creators. "
        "Do not include any spoilers for the climax of the show's final season."
    )

    app = create_workflow_async()

    # Only 'initial_prompt' drives the run; the other fields are filled in by
    # the nodes and are initialised empty here.
    state_input: GraphState = {
        "initial_prompt": test_instruction,
        "plan": "",
        "write_steps": [],
        "final_json": "",
    }

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    final_state = asyncio.run(app.ainvoke(state_input))
    elapsed = time.perf_counter() - start

    print("\n===== FINAL JSON OUTPUT =====\n")
    print(final_state['final_json'])
    print("=============================\n")

    print("\n===== PERFORMANCE =====\n")
    print(f"Time taken: {elapsed:.2f} seconds")
    print("=======================\n")