import gradio as gr

# cell 1
from typing import Annotated
from langchain_experimental.tools import PythonAstREPLTool
import pandas as pd
import statsmodels.api as sm  # imported so the dependency is present; the agents import it again inside the REPL
import os

# df = pd.read_csv("HOUST.csv")
df = pd.read_csv("USSTHPI.csv")
python_repl_tool = PythonAstREPLTool(locals={"df": df})
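# A quick sanity check that the REPL tool can see the dataframe (a sketch; the
# exact output depends on the CSV):
#   python_repl_tool.invoke("df.shape")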

# cell 2
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
import functools
import operator
from typing import Sequence, TypedDict

system_prompt = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
                It is important to understand the attributes of the dataframe before working with it. This is the result of running `df.head().to_markdown()`

                <df>
                {dhead}
                </df>

                You are not meant to use only these rows to answer questions - they are meant as a way of telling you about the shape and schema of the dataframe.
                You also do not have to use only the information here to answer questions - you can run intermediate queries to do exploratory data analysis as needed."""
system_prompt = system_prompt.format(dhead=df.head().to_markdown())

# The agent state is the input to each node in the graph
class AgentState(TypedDict):
    # The operator.add annotation tells the graph to append new messages to the current state
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # The 'next' field indicates where to route to next
    next: str
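# For example, if state["messages"] is [m1] and a node returns {"messages": [m2]},
# the graph applies operator.add so the next node sees [m1, m2]. In effect:
#   merged = state["messages"] + update["messages"]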

# system_prompt is already formatted, so pass it as a SystemMessage rather than a
# ("system", ...) template tuple, which would re-parse any braces in the data as
# template variables. Not everything here needs to be an OpenAI tools agent, but
# it is the simplest fit for these worker nodes.
def create_agent(llm: ChatOpenAI, tools: list, task: str):
    # Each worker node will be given a name and some tools.
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),  # global, pre-formatted system prompt
            HumanMessage(content=task),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    # for debugging
    # executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor

# The agent returns an AIMessage carrying tool-call metadata, so wrap its final
# output in a plain HumanMessage tagged with the node name for downstream nodes.
def agent_node(state: AgentState, agent, name):
    result = agent.invoke(state)
    return {"messages": [HumanMessage(content=result["output"], name=name)]}

# Returning a dict with a "messages" key is enough; the graph merges it into
# state automatically via the operator.add annotation on AgentState.
def chain_node(state: AgentState, chain, name):
    result = chain.invoke(input={"detail": "medium", "messages": state["messages"]})
    return {"messages": [HumanMessage(content=result.content, name=name)]}

# cell 3
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
llm = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0, api_key=OPENAI_API_KEY)
llm_big = ChatOpenAI(model="gpt-4o", temperature=0, api_key=OPENAI_API_KEY)

eda_task = """Using the data in the dataframe `df` and the package statsmodels, first run an augmented Dickey-Fuller test on the data.
            Using matplotlib, plot the time series, display it and save it to 'plot.png'.
            Next use the statsmodels package to generate an ACF plot with the zero flag set to False, display it and save it to 'acf.png'.
            Then use the statsmodels package to generate a PACF plot with the zero flag set to False, display it and save it to 'pacf.png'."""
eda_agent = create_agent(llm, [python_repl_tool], task=eda_task)
eda_node = functools.partial(agent_node, agent=eda_agent, name="EDA")

difference_task = """Using the data in the dataframe `df`, determine whether a log transformation is appropriate.
                    If a log transformation is appropriate, generate a new column for the log of the series and use this data for analysis.
                    Then determine whether a linear difference is needed and, if needed, generate a new column for the differenced data.
                    If the data was differenced, use the differenced data for analysis."""
diff_agent = create_agent(llm, [python_repl_tool], task=difference_task)
diff_node = functools.partial(agent_node, agent=diff_agent, name="difference")

plot_template = ChatPromptTemplate.from_messages(
    messages=[
        SystemMessage(content="""Determine whether this time series is stationary or whether it needs to be differenced.
                      Consider the results of the ADF test along with the plot of the time series, the ACF plot and the PACF plot."""),
        MessagesPlaceholder(variable_name="messages"),
        HumanMessagePromptTemplate.from_template(
            template=[{"type": "image_url", "image_url": {"path": "plot.png"}},
                        {"type": "image_url", "image_url": {"path": "acf.png"}},
                        {"type": "image_url", "image_url": {"path": "pacf.png"}}]),
    ]
)

plot_chain = plot_template | llm_big
plot_node = functools.partial(chain_node, chain=plot_chain, name="PlotAnalysis")
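# Note: the "path" form of image_url relies on LangChain reading the local file
# and base64-encoding it at prompt time, so the plots written by the EDA node
# must already exist on disk when this chain runs.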

def router(state):
    router_template = ChatPromptTemplate.from_messages(
        messages=[
            MessagesPlaceholder(variable_name="messages"),
            HumanMessage("""If the time series is stationary, return true; if it is not stationary, return false.
                         Return only true or false, nothing else.""")
        ]
    )

    router_chain = router_template | llm
    response = router_chain.invoke({"messages": state["messages"]})

    if response.content.strip().lower() == "true":
        return "ARIMA"
    else:
        return "Difference"

arima_task = """Using the data in the dataframe `df` and the package statsmodels,
    estimate an ARIMA model with the appropriate AR and MA terms.
    Then display the model results.
    Finally generate an autocorrelation and a partial autocorrelation plot of the model residuals with the zero flag set to False, display them and save them to 'resid_acf.png'."""

arima_agent = create_agent(llm, [python_repl_tool], task=arima_task)
arima_node = functools.partial(agent_node, agent=arima_agent, name="ARIMA")

from langgraph.graph import END, StateGraph, START

# TODO: add a chain to this node to analyze the ACF plot
workflow = StateGraph(AgentState)
workflow.add_node("EDA", eda_node)
workflow.add_node("PlotAnalysis", plot_node)
workflow.add_node("Difference", diff_node)
workflow.add_node("ARIMA", arima_node)

# TODO: add a conditional edge to refit the model, looping on the residual diagnostics
workflow.add_edge(START, "EDA")
workflow.add_edge("EDA", "PlotAnalysis")
workflow.add_conditional_edges("PlotAnalysis", router)
workflow.add_edge("Difference", "EDA")
workflow.add_edge("ARIMA", END)

graph = workflow.compile()
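# A minimal local smoke test (a sketch; assumes OPENAI_API_KEY is set and the
# CSV is present):
#   result = graph.invoke({"messages": [HumanMessage(content="Run the analysis")]})
#   print(result["messages"][-1].content)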

from langgraph_sdk import get_client

# Initialize the LangGraph client
client = get_client(url="https://huggingface.co/spaces/pwilczewski/gradiobox")
assistant_id = "graph"

async def stream_response(input_data):
    thread = await client.threads.create()
    async for chunk in client.runs.stream(
        thread["thread_id"],
        assistant_id,
        input=input_data,
        stream_mode="values"
    ):
        yield chunk.data  # Yield the data as it is received
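# With stream_mode="values", each chunk is the full accumulated graph state
# rather than a delta, so every yield carries the complete message list so far.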

# Gradio streams output when the handler is an async generator, so yield chunks
# here rather than returning the generator object itself.
async def gradio_interface(input_text):
    # The graph currently ignores the text input and always runs the same analysis.
    input_data = {"messages": [HumanMessage(content="Run the analysis")]}
    async for chunk in stream_response(input_data):
        yield chunk

demo = gr.Interface(fn=gradio_interface, inputs="text", outputs="text", live=True)
demo.launch()