Spaces:
Running
Running
Zekun Wu
commited on
Commit
·
97becca
1
Parent(s):
f5c8eb4
update
Browse files- pages/1_Injection.py +13 -6
- requirements.txt +2 -1
- util/model.py +47 -0
pages/1_Injection.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
|
|
2 |
import pandas as pd
|
3 |
from io import StringIO
|
4 |
from util.injection import process_scores_multiple
|
5 |
-
from util.model import AzureAgent, GPTAgent
|
6 |
from util.prompt import PROMPT_TEMPLATE
|
7 |
import os
|
8 |
|
@@ -49,14 +49,18 @@ else:
|
|
49 |
st.sidebar.title('Model Settings')
|
50 |
initialize_state()
|
51 |
|
|
|
|
|
52 |
# Model selection and configuration
|
53 |
-
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
|
54 |
st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
|
55 |
st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
|
56 |
st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
60 |
|
61 |
if st.sidebar.button("Reset Model Info"):
|
62 |
initialize_state() # Reset all state to defaults
|
@@ -111,9 +115,12 @@ else:
|
|
111 |
if model_type == 'AzureAgent':
|
112 |
agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
|
113 |
st.session_state.deployment_name)
|
114 |
-
|
115 |
agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
|
116 |
st.session_state.deployment_name, api_version)
|
|
|
|
|
|
|
117 |
|
118 |
with st.spinner('Processing data...'):
|
119 |
parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
|
|
|
2 |
import pandas as pd
|
3 |
from io import StringIO
|
4 |
from util.injection import process_scores_multiple
|
5 |
+
from util.model import AzureAgent, GPTAgent,Claude3Agent
|
6 |
from util.prompt import PROMPT_TEMPLATE
|
7 |
import os
|
8 |
|
|
|
49 |
st.sidebar.title('Model Settings')
|
50 |
initialize_state()
|
51 |
|
52 |
+
|
53 |
+
|
54 |
# Model selection and configuration
|
55 |
+
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent','Claude3Agent'))
|
56 |
st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
|
57 |
st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
|
58 |
st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
|
59 |
+
|
60 |
+
if model_type == 'GPTAgent' or model_type == 'AzureAgent':
|
61 |
+
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
|
62 |
+
st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
|
63 |
+
st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
|
64 |
|
65 |
if st.sidebar.button("Reset Model Info"):
|
66 |
initialize_state() # Reset all state to defaults
|
|
|
115 |
if model_type == 'AzureAgent':
|
116 |
agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
|
117 |
st.session_state.deployment_name)
|
118 |
+
elif model_type == 'GPTAgent':
|
119 |
agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
|
120 |
st.session_state.deployment_name, api_version)
|
121 |
+
else:
|
122 |
+
agent = Claude3Agent(st.session_state.api_key,st.session_state.deployment_name)
|
123 |
+
|
124 |
|
125 |
with st.spinner('Processing data...'):
|
126 |
parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
|
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ scipy
|
|
5 |
statsmodels
|
6 |
scikit-posthocs
|
7 |
json-repair
|
8 |
-
plotly
|
|
|
|
5 |
statsmodels
|
6 |
scikit-posthocs
|
7 |
json-repair
|
8 |
+
plotly
|
9 |
+
boto3
|
util/model.py
CHANGED
@@ -1,6 +1,49 @@
|
|
1 |
import json
|
2 |
import http.client
|
3 |
from openai import AzureOpenAI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
class ContentFormatter:
|
6 |
@staticmethod
|
@@ -53,3 +96,7 @@ class GPTAgent:
|
|
53 |
**kwargs
|
54 |
)
|
55 |
return response.choices[0].message.content
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
import http.client
|
3 |
from openai import AzureOpenAI
|
4 |
+
import time
|
5 |
+
from tqdm import tqdm
|
6 |
+
from typing import Any, List
|
7 |
+
from botocore.exceptions import ClientError
|
8 |
+
from enum import Enum
|
9 |
+
import boto3
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
|
13 |
+
|
14 |
+
class Model(Enum):
|
15 |
+
CLAUDE3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
|
16 |
+
CLAUDE3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0"
|
17 |
+
|
18 |
+
|
19 |
+
class Claude3Agent:
|
20 |
+
def __init__(self, aws_secret_access_key: str,model: str ):
|
21 |
+
self.client = boto3.client("bedrock-runtime", region_name="us-east-1", aws_access_key_id="AKIAZR6ZJPKTKJAMLP5W",
|
22 |
+
aws_secret_access_key=aws_secret_access_key)
|
23 |
+
if model == "SONNET":
|
24 |
+
self.model = Model.CLAUDE3_SONNET
|
25 |
+
elif model == "HAIKU":
|
26 |
+
self.model = Model.CLAUDE3_HAIKU
|
27 |
+
else:
|
28 |
+
raise ValueError("Invalid model type. Please choose from 'SONNET' or 'HAIKU' models.")
|
29 |
+
|
30 |
+
def invoke(self, text: str,**kwargs) -> str:
|
31 |
+
try:
|
32 |
+
body = json.dumps(
|
33 |
+
{
|
34 |
+
"anthropic_version": "bedrock-2023-05-31",
|
35 |
+
"messages": [
|
36 |
+
{"role": "user", "content": [{"type": "text", "text": text}]}
|
37 |
+
],
|
38 |
+
**kwargs
|
39 |
+
}
|
40 |
+
)
|
41 |
+
response = self.client.invoke_model(modelId=self.model.value, body=body)
|
42 |
+
completion = json.loads(response["body"].read())["content"][0]["text"]
|
43 |
+
return completion
|
44 |
+
except ClientError:
|
45 |
+
logging.error("Couldn't invoke model")
|
46 |
+
raise
|
47 |
|
48 |
class ContentFormatter:
|
49 |
@staticmethod
|
|
|
96 |
**kwargs
|
97 |
)
|
98 |
return response.choices[0].message.content
|
99 |
+
|
100 |
+
if __name__ == '__main__':
|
101 |
+
agent = Claude3Agent("TyzS1CYdvhYtes+V9u2qqS5sggS3asSeXAfYYvOS", "SONNET")
|
102 |
+
print(agent.invoke("I am a software engineer.", max_tokens=200, temperature=0.5))
|