Spaces:
Running
Running
File size: 7,430 Bytes
bdafe83 53709ed bdafe83 a06e98d bdafe83 a06e98d bdafe83 a06e98d bdafe83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import json
import logging
import os.path as osp
from typing import List
from agentreview.environments import Conversation
from agentreview.utility.utils import get_rebuttal_dir
from .base import TimeStep
from ..message import Message
from ..paper_review_message import PaperReviewMessagePool
logger = logging.getLogger(__name__)


class PaperReview(Conversation):
    """
    Discussion between reviewers and area chairs.

    The reviewing process is split into phases (see the `phases` property):
        0. paper_extraction: the paper content is extracted; no LLM-based agents are called.
        1. reviewer_write_reviews: reviewers write their reviews based on the paper content.
        2. author_reviewer_discussion: the author responds to each reviewer's review.
        3. reviewer_ac_discussion: reviewers and the area chair (AC) discuss the paper.
        4. ac_write_metareviews: the AC writes meta-reviews.
        5. ac_makes_decisions: the AC makes the final decision.

    `step` marks the episode terminal at the end of phase 4; phase 5 is run separately
    (`run_paper_decision_cli.py`).
    """

    type_name = "paper_review"

    def __init__(self, player_names: List[str], paper_id: int, paper_decision: str, experiment_setting: dict, args,
                 parallel: bool = False,
                 **kwargs):
        """
        Args:
            player_names (List[str]): names of all players. The number of names starting with "Reviewer"
                determines how many reviewers appear in the phase schedule.
            paper_id (int): the id of the paper, such as 917
            paper_decision (str): the decision of the paper, such as "Accept: notable-top-25%"
            experiment_setting (dict): experiment configuration. It is passed to `PaperReviewMessagePool`,
                and its optional 'player_to_test' entry is stored on `self.player_to_test`.
            args: run configuration. `experiment_name` is read here; `model_name`, `conference` and
                `num_reviewers_per_paper` are read by `load_message_history_from_cache`.
            parallel (bool): forwarded to the base environment and stored on `self.parallel`.
            **kwargs: forwarded to the base environment; an optional "task" entry is stored on `self.task`.
        """
        # Deliberately bypass `Conversation.__init__` and call the initializer of Conversation's parent
        # class instead; this class configures its own message pool below.
        # NOTE(review): `_current_turn` and `_next_player_index` are read by `step` / `get_next_player`
        # but are not set here; presumably `reset()` initializes them — confirm the arena always calls
        # `reset()` before stepping.
        super(Conversation, self).__init__(player_names=player_names, parallel=parallel, **kwargs)
        self.args = args
        self.paper_id = paper_id
        self.paper_decision = paper_decision
        self.parallel = parallel
        self.experiment_setting = experiment_setting
        self.player_to_test = experiment_setting.get('player_to_test', None)
        self.task = kwargs.get("task")
        self.experiment_name = args.experiment_name

        # The "state" of the environment is maintained by the message pool
        self.message_pool = PaperReviewMessagePool(experiment_setting)

        self.phase_index = 0
        self._phases = None  # built lazily by the `phases` property

    @property
    def phases(self):
        """Phase schedule: maps a phase index to {"name": ..., "speaking_order": [...]}.

        Built on first access and cached in `self._phases`. Reviewer names in the schedule are normalized
        to "Reviewer 1" ... "Reviewer N", where N is the number of player names starting with "Reviewer".
        """
        if self._phases is not None:
            return self._phases

        num_reviewers = sum(1 for name in self.player_names if name.startswith("Reviewer"))
        reviewer_names = [f"Reviewer {i}" for i in range(1, num_reviewers + 1)]

        self._phases = {
            # In phase 0, no LLM-based agents are called.
            0: {
                "name": "paper_extraction",
                'speaking_order': ["Paper Extractor"],
            },
            1: {
                "name": 'reviewer_write_reviews',
                'speaking_order': reviewer_names
            },
            # The author responds to each reviewer's review
            2: {
                'name': 'author_reviewer_discussion',
                'speaking_order': ["Author" for _ in reviewer_names],
            },
            3: {
                'name': 'reviewer_ac_discussion',
                'speaking_order': ["AC"] + reviewer_names,
            },
            4: {
                'name': 'ac_write_metareviews',
                'speaking_order': ["AC"]
            },
            5: {
                'name': 'ac_makes_decisions',
                'speaking_order': ["AC"]
            },
        }
        # Return the cached dict directly instead of re-entering the property.
        return self._phases

    @phases.setter
    def phases(self, value):
        self._phases = value

    def reset(self):
        """Rewind to phase 0 and delegate the remaining reset work to the parent class."""
        self._current_phase = "review"
        self.phase_index = 0
        return super().reset()

    def load_message_history_from_cache(self):
        """Pre-populate the message pool from the BASELINE experiment and continue from phase 4.

        Does nothing unless the environment is still in phase 0. Reads
        `<rebuttal dir of the BASELINE experiment>/<paper_id>.json`, skips "Paper Extractor" messages,
        and appends every remaining message up to and including the first AC message. The environment
        then resumes at phase 4 (ac_write_metareviews).

        Raises:
            AssertionError: if the cached discussion does not contain exactly
                `args.num_reviewers_per_paper` distinct reviewers. The check runs before the message
                pool is modified.
        """
        if self.phase_index != 0:
            return

        print("Loading message history from BASELINE experiment")
        full_paper_discussion_path = get_rebuttal_dir(paper_id=self.paper_id,
                                                      experiment_name="BASELINE",
                                                      model_name=self.args.model_name,
                                                      conference=self.args.conference)

        with open(osp.join(full_paper_discussion_path, f"{self.paper_id}.json"), 'r', encoding='utf-8') as f:
            messages = json.load(f)['messages']

        # Validate before touching the message pool so a bad cache does not leave it half-populated.
        num_unique_reviewers = len(
            {msg['agent_name'] for msg in messages if msg['agent_name'].startswith("Reviewer")})
        assert num_unique_reviewers == self.args.num_reviewers_per_paper, (
            f"Expected {self.args.num_reviewers_per_paper} reviewers in the cached BASELINE discussion of "
            f"paper {self.paper_id}, found {num_unique_reviewers}.")

        num_messages_from_AC = 0
        for msg in messages:
            # We have already extracted contents from the paper.
            if msg['agent_name'] == "Paper Extractor":
                continue

            # Encountering the 2nd message from the AC. Stop loading messages.
            if msg['agent_name'] == "AC" and num_messages_from_AC == 1:
                break

            if msg['agent_name'] == "AC":
                num_messages_from_AC += 1

            self.message_pool.append_message(Message(**msg))

        self.phase_index = 4

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Records `action` as a message from `player_name`, then advances to the next speaker of the
        current phase. After the last speaker of a phase, moves to the next phase and increments the
        turn counter; after the last speaker of phase 4, the episode is marked terminal instead.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            TimeStep whose observation is all messages in the pool, with zero rewards.
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        speaking_order = self.phases[self.phase_index]["speaking_order"]

        logger.info("Phase %s: %s | Player %s: %s",
                    self.phase_index, self.phases[self.phase_index]['name'],
                    self._next_player_index, speaking_order[self._next_player_index])

        terminal = self.is_terminal()

        # Reached the end of the speaking order. Move to the next phase.
        if self._next_player_index == len(speaking_order) - 1:
            self._next_player_index = 0

            if self.phase_index == 4:
                # Phase V (AC makes decisions) is run by a separate script, so the simulation stops here.
                terminal = True
                logger.info(
                    "Finishing the simulation for Phase I - IV. Please run `python run_paper_decision_cli.py ` for "
                    "Phase V. (AC makes decisions).")
            else:
                print(f"Phase {self.phase_index}: end of the speaking order. Move to Phase ({self.phase_index + 1}).")
                self.phase_index += 1
                self._current_turn += 1
        else:
            self._next_player_index += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=terminal,
        )  # Return all the messages
        return timestep

    def get_next_player(self) -> str:
        """Get the next player in the current phase."""
        speaking_order = self.phases[self.phase_index]["speaking_order"]
        next_player = speaking_order[self._next_player_index]
        return next_player

    def get_observation(self, player_name=None) -> List[Message]:
        """Get observation for the player.

        With no `player_name`, returns every message in the pool. Otherwise returns the messages the
        message pool deems visible to that player for the current phase and speaker position.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        return self.message_pool.get_visible_messages_for_paper_review(
            player_name, phase_index=self.phase_index, next_player_idx=self._next_player_index,
            player_names=self.player_names
        )
|