bstraehle commited on
Commit
073d6c4
1 Parent(s): 16ec590

Create assistant.py

Browse files
Files changed (1) hide show
  1. assistant.py +195 -0
assistant.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import yfinance as yf
3
+
4
+ import json, openai, os, time
5
+
6
+ from datetime import date
7
+ from openai import OpenAI
8
+ from tavily import TavilyClient
9
+ from typing import List
10
+ from utils import function_to_schema, show_json
11
+
12
+ openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
13
+ tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
14
+
15
+ assistant_id = "asst_DbCpNsJ0vHSSdl6ePlkKZ8wG"
16
+
17
+ assistant, thread = None, None
18
+
19
+ def today_tool() -> str:
20
+ """Returns today's date. Use this function for any questions related to knowing today's date.
21
+ There should be no input. This function always returns today's date."""
22
+ return str(date.today())
23
+
24
+ def yf_download_tool(tickers: List[str], start_date: date, end_date: date) -> pd.DataFrame:
25
+ """Returns historical stock data for a list of given tickers from start date to end date
26
+ using the yfinance library download function.
27
+ Use this function for any questions related to getting historical stock data.
28
+ The input should be the tickers as a List of strings, a start date, and an end date.
29
+ This function always returns a pandas DataFrame."""
30
+ return yf.download(tickers, start=start_date, end=end_date)
31
+
32
+ def tavily_search_tool(query: str) -> str:
33
+ """Searches the web for a given query and returns an answer, "
34
+ ready for use as context in a RAG application, using the Tavily API.
35
+ Use this function for any questions requiring knowledge not available to the model.
36
+ The input should be the query string. This function always returns an answer string."""
37
+ return tavily_client.get_search_context(query=query, max_results=5)
38
+
39
+ tools = {
40
+ "today_tool": today_tool,
41
+ "yf_download_tool": yf_download_tool,
42
+ "tavily_search_tool": tavily_search_tool,
43
+ }
44
+
45
+ def create_assistant(openai_client):
46
+ assistant = openai_client.beta.assistants.create(
47
+ name="Python Coding Assistant",
48
+ instructions=(
49
+ "You are a Python programming language expert that "
50
+ "generates Pylint-compliant code and explains it. "
51
+ "Execute code when explicitly asked to."
52
+ ),
53
+ model="gpt-4o",
54
+ tools=[
55
+ {"type": "code_interpreter"},
56
+ {"type": "function", "function": function_to_schema(today_tool)},
57
+ {"type": "function", "function": function_to_schema(yf_download_tool)},
58
+ {"type": "function", "function": function_to_schema(tavily_search_tool)},
59
+ ],
60
+ )
61
+
62
+ show_json("assistant", assistant)
63
+
64
+ return assistant
65
+
66
+ def load_assistant(openai_client):
67
+ assistant = openai_client.beta.assistants.retrieve(assistant_id)
68
+ show_json("assistant", assistant)
69
+ return assistant
70
+
71
+ def create_thread(openai_client):
72
+ thread = openai_client.beta.threads.create()
73
+ show_json("thread", thread)
74
+ return thread
75
+
76
+ def create_message(openai_client, thread, msg):
77
+ message = openai_client.beta.threads.messages.create(
78
+ role="user",
79
+ thread_id=thread.id,
80
+ content=msg,
81
+ )
82
+
83
+ show_json("message", message)
84
+ return message
85
+
86
+ def create_run(openai_client, assistant, thread):
87
+ run = openai_client.beta.threads.runs.create(
88
+ assistant_id=assistant.id,
89
+ thread_id=thread.id,
90
+ parallel_tool_calls=False,
91
+ )
92
+
93
+ show_json("run", run)
94
+ return run
95
+
96
+ def wait_on_run(openai_client, thread, run):
97
+ while run.status == "queued" or run.status == "in_progress":
98
+ run = openai_client.beta.threads.runs.retrieve(
99
+ thread_id=thread.id,
100
+ run_id=run.id,
101
+ )
102
+
103
+ time.sleep(1)
104
+
105
+ show_json("run", run)
106
+
107
+ if hasattr(run, "last_error") and run.last_error:
108
+ raise gr.Error(run.last_error)
109
+
110
+ return run
111
+
112
+ def get_run_steps(openai_client, thread, run):
113
+ run_steps = openai_client.beta.threads.runs.steps.list(
114
+ thread_id=thread.id,
115
+ run_id=run.id,
116
+ order="asc",
117
+ )
118
+
119
+ show_json("run_steps", run_steps)
120
+ return run_steps
121
+
122
+ def execute_tool_call(tool_call):
123
+ name = tool_call.function.name
124
+ args = {}
125
+
126
+ if len(tool_call.function.arguments) > 10:
127
+ args = json.loads(tool_call.function.arguments)
128
+
129
+ return tools[name](**args)
130
+
131
+ def execute_tool_calls(run_steps):
132
+ run_step_details = []
133
+
134
+ tool_call_ids = []
135
+ tool_call_results = []
136
+
137
+ for step in run_steps.data:
138
+ step_details = step.step_details
139
+ run_step_details.append(step_details)
140
+ show_json("step_details", step_details)
141
+
142
+ if hasattr(step_details, "tool_calls"):
143
+ for tool_call in step_details.tool_calls:
144
+ show_json("tool_call", tool_call)
145
+
146
+ if hasattr(tool_call, "function"):
147
+ tool_call_ids.append(tool_call.id)
148
+ tool_call_results.append(execute_tool_call(tool_call))
149
+
150
+ return tool_call_ids, tool_call_results
151
+
152
+ def get_messages(openai_client, thread):
153
+ messages = openai_client.beta.threads.messages.list(
154
+ thread_id=thread.id
155
+ )
156
+
157
+ show_json("messages", messages)
158
+ return messages
159
+
160
+ def extract_content_values(data):
161
+ text_values, image_values = [], []
162
+
163
+ for item in data.data:
164
+ for content in item.content:
165
+ if content.type == "text":
166
+ text_value = content.text.value
167
+ text_values.append(text_value)
168
+ if content.type == "image_file":
169
+ image_value = content.image_file.file_id
170
+ image_values.append(image_value)
171
+
172
+ return text_values, image_values
173
+
174
+ ###
175
+ def generate_tool_outputs(tool_call_ids, tool_call_results):
176
+ tool_outputs = []
177
+
178
+ for tool_call_id, tool_call_result in zip(tool_call_ids, tool_call_results):
179
+ tool_output = {}
180
+
181
+ try:
182
+ tool_output = {
183
+ "tool_call_id": tool_call_id,
184
+ "output": tool_call_result.to_json()
185
+ }
186
+ except AttributeError:
187
+ tool_output = {
188
+ "tool_call_id": tool_call_id,
189
+ "output": tool_call_result
190
+ }
191
+
192
+ tool_outputs.append(tool_output)
193
+
194
+ return tool_outputs
195
+ ###