Cartinoe5930 committed
Commit d816c1c · 1 Parent(s): 312ba7c

Create model_inference.py

Files changed (1):
  1. model_inference.py +110 -0
model_inference.py ADDED
@@ -0,0 +1,110 @@
+ import requests
+ import openai
+ import json
+ import time
+
+ def load_json(prompt_path, endpoint_path):
+     # Load the per-model prompt templates and inference endpoint URLs.
+     with open(prompt_path, "r") as prompt_file:
+         prompt_dict = json.load(prompt_file)
+
+     with open(endpoint_path, "r") as endpoint_file:
+         endpoint_dict = json.load(endpoint_file)
+
+     return prompt_dict, endpoint_dict
+
+ def construct_message(agents, instruction, idx):
+     # With no other agents to refer to, simply ask the model to double check its answer.
+     if len(agents) == 0:
+         prompt = "Can you double check that your answer is correct? Please reiterate your answer, making sure to state your answer at the end of the response."
+         return prompt
+
+     contexts = [agents[0][idx]['content'], agents[1][idx]['content'], agents[2][idx]['content']]
+
+     # system prompt & user prompt for gpt-3.5-turbo
+     sys_prompt = "I want you to act as a summarizer. You can look at multiple responses and summarize the main points of them so that the meaning is not lost. Multiple responses will be given, which are responses from several different models to a single question. And you should use your excellent summarizing skills to output the best summary."
+     summarize_prompt = f"[Response 1]: {contexts[0]}\n[Response 2]: {contexts[1]}\n[Response 3]: {contexts[2]}\n\nThese are the responses of each model to a certain question. Summarize them comprehensively without compromising the meaning of each response."
+
+     message = [
+         {"role": "system", "content": sys_prompt},
+         {"role": "user", "content": summarize_prompt},
+     ]
+
+     completion = openai.ChatCompletion.create(
+         model="gpt-3.5-turbo-16k-0613",
+         messages=message,
+         max_tokens=256,
+         n=1
+     )
+     # Use only the summarizer's text, not the whole API response object.
+     summary = completion["choices"][0]["message"]["content"]
+
+     prefix_string = f"This is the summarization of recent/updated opinions from other agents: {summary}"
+     prefix_string = prefix_string + "\n\nUse this summarization carefully as additional advice. Can you provide an updated answer? Make sure to state your answer at the end of the response. " + instruction
+     return prefix_string
+
+ def generate_question(agents, question):
+     # Initialize each agent's context with the question itself.
+     agent_contexts = [[{"model": agent, "content": question}] for agent in agents]
+
+     content = agent_contexts[0][0]["content"]
+
+     return agent_contexts, content
+
+ def generate_answer(model, formatted_prompt):
+     # Query the model's inference endpoint; prompt_dict, endpoint_dict and
+     # AUTH_TOKEN are module-level globals set up in Inference().
+     API_URL = endpoint_dict[model]
+     headers = {"Authorization": f"Bearer {AUTH_TOKEN}"}
+     payload = {"inputs": formatted_prompt}
+     try:
+         resp = requests.post(API_URL, json=payload, headers=headers)
+         response = resp.json()
+     except Exception:
+         print("retrying due to an error......")
+         time.sleep(5)
+         return generate_answer(model, formatted_prompt)
+
+     # Keep only the text that follows the model-specific response delimiter.
+     return {"model": model, "content": response[0]["generated_text"].split(prompt_dict[model]["response_split"])[-1]}
+
+ def prompt_formatting(model, instruction, cot):
+     # Select the model-specific prompt template (alpaca/orca use the no-input variant).
+     if model == "alpaca" or model == "orca":
+         prompt = prompt_dict[model]["prompt_no_input"]
+     else:
+         prompt = prompt_dict[model]["prompt"]
+
+     if cot:
+         instruction += " Let's think step by step."
+
+     return {"model": model, "content": prompt.format(instruction)}
+
+ def Inference(model_list, question, API_KEY, auth_token, round, cot):
+     # Share the loaded templates, endpoints, and auth token with the helpers above.
+     global prompt_dict, endpoint_dict, AUTH_TOKEN
+
+     openai.api_key = API_KEY
+     AUTH_TOKEN = auth_token
+
+     prompt_dict, endpoint_dict = load_json("src/prompt_template.json", "src/inference_endpoint.json")
+
+     rounds = round
+
+     generated_description = []
+
+     agent_contexts, content = generate_question(agents=model_list, question=question)
+
+     # Debate
+     for debate in range(rounds + 1):
+         # Refer to the summarized previous responses
+         if debate != 0:
+             message = construct_message(agent_contexts, content, 2 * debate - 1)
+             for i in range(len(agent_contexts)):
+                 agent_contexts[i].append(prompt_formatting(agent_contexts[i][-1]["model"], message, cot))
+
+         # Generate a new response based on the summarized responses
+         for agent_context in agent_contexts:
+             completion = generate_answer(agent_context[-1]["model"], agent_context[-1]["content"] if debate != 0 else prompt_formatting(agent_context[-1]["model"], agent_context[-1]["content"], cot)["content"])
+             agent_context.append(completion)
+
+     # Collect each agent's answer from every round plus the intermediate summarizations.
+     models_response = {
+         model_list[0]: [agent_contexts[0][1]["content"], agent_contexts[0][3]["content"], agent_contexts[0][-1]["content"]],
+         model_list[1]: [agent_contexts[1][1]["content"], agent_contexts[1][3]["content"], agent_contexts[1][-1]["content"]],
+         model_list[2]: [agent_contexts[2][1]["content"], agent_contexts[2][3]["content"], agent_contexts[2][-1]["content"]]
+     }
+     response_summarization = [
+         agent_contexts[0][2], agent_contexts[0][4]
+     ]
+     generated_description.append({"question": content, "agent_response": models_response, "summarization": response_summarization})
+
+     return generated_description
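
For reference, here is a minimal sketch of what the two JSON files loaded by load_json are assumed to contain. The key layout is inferred from how prompt_formatting and generate_answer index the dictionaries (a prompt or prompt_no_input template and a response_split delimiter per model, plus one endpoint URL per model); the concrete model names, template text, and URLs below are illustrative placeholders, not the repository's actual files.

# Hypothetical shape of src/prompt_template.json, shown as a Python dict.
# prompt.format(instruction) is called with one positional argument, so each
# template needs exactly one "{}" placeholder.
prompt_template = {
    "alpaca": {
        "prompt_no_input": (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{}\n\n### Response:\n"
        ),
        "response_split": "### Response:",   # generate_answer splits on this
    },
    "orca": {
        "prompt_no_input": "### Instruction:\n{}\n\n### Response:\n",
        "response_split": "### Response:",
    },
    "llama": {
        "prompt": "### User:\n{}\n\n### Assistant:\n",
        "response_split": "### Assistant:",
    },
}

# Hypothetical shape of src/inference_endpoint.json: one URL per model key.
inference_endpoint = {
    "alpaca": "https://alpaca.example.endpoints.huggingface.cloud",
    "orca": "https://orca.example.endpoints.huggingface.cloud",
    "llama": "https://llama.example.endpoints.huggingface.cloud",
}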
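
And a small, hedged usage sketch of Inference. The hard-coded indexing in construct_message and models_response assumes exactly three agents, and the [1]/[3]/[-1] selection lines up with round=2 (initial answers plus two debate rounds); the model names and credentials below are placeholders.

# Hypothetical driver for Inference(); all values are placeholders.
if __name__ == "__main__":
    result = Inference(
        model_list=["alpaca", "orca", "llama"],  # must match keys in both JSON files
        question="What is the capital of France?",
        API_KEY="YOUR_OPENAI_KEY",    # used by the gpt-3.5-turbo summarizer
        auth_token="YOUR_HF_TOKEN",   # bearer token for the inference endpoints
        round=2,                      # two debate rounds after the initial answers
        cot=True,                     # append "Let's think step by step."
    )
    # Each entry holds the question, every agent's per-round answers,
    # and the intermediate summarizations.
    print(result[0]["agent_response"])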