import json
import logging
import os.path as osp
from typing import List

from agentreview.environments import Conversation
from agentreview.utility.utils import get_rebuttal_dir
from .base import TimeStep
from ..message import Message
from ..paper_review_message import PaperReviewMessagePool


logger = logging.getLogger(__name__)

class PaperReview(Conversation):
    """
    Discussion between reviewers and area chairs.

    There are several phases in the reviewing process:
        reviewer_write_reviews: reviewers write their reviews based on the paper content.
        author_reviewer_discussion: An author respond to comments from the reviewers.
        reviewer_ac_discussion: reviewers and an area chair discuss the paper.
        ac_discussion: an area chair makes the final decision.
    """

    type_name = "paper_review"

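    def __init__(self, player_names: List[str], paper_id: int, paper_decision: str, experiment_setting: dict, args,
                 parallel: bool = False, **kwargs):
        """
        Args:
            player_names (List[str]): the names of all players, e.g. "Reviewer 1", "Author", and "AC".
            paper_id (int): the id of the paper, such as 917.
            paper_decision (str): the decision of the paper, such as "Accept: notable-top-25%".
            experiment_setting (dict): the experiment configuration; its optional "player_to_test" entry is stored.
            args: parsed command-line arguments; `experiment_name`, `model_name`, `conference`, and
                `num_reviewers_per_paper` are read from it.
            parallel (bool): whether the players act in parallel.
        """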

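        # Call the initializer of `Conversation`'s parent class directly, bypassing `Conversation.__init__`;
        # the message pool is set up below with the paper-review-specific `PaperReviewMessagePool`.
        super(Conversation, self).__init__(player_names=player_names, parallel=parallel, **kwargs)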
        self.args = args
        self.paper_id = paper_id
        self.paper_decision = paper_decision
        self.parallel = parallel
        self.experiment_setting = experiment_setting
        self.player_to_test = experiment_setting.get('player_to_test', None)
        self.task = kwargs.get("task")
        self.experiment_name = args.experiment_name

        # The "state" of the environment is maintained by the message pool
        self.message_pool = PaperReviewMessagePool(experiment_setting)

        self.phase_index = 0
        self._phases = None

    @property
    def phases(self):

        if self._phases is not None:
            return self._phases

        # Only the number of reviewers matters; they are renamed "Reviewer 1" to "Reviewer N".
        num_reviewers = len([name for name in self.player_names if name.startswith("Reviewer")])
        reviewer_names = [f"Reviewer {i}" for i in range(1, num_reviewers + 1)]

        self._phases = {
            # In phase 0, no LLM-based agents are called.
            0: {
                "name": "paper_extraction",
                'speaking_order': ["Paper Extractor"],
            },

            1: {
                "name": 'reviewer_write_reviews',
                'speaking_order': reviewer_names
            },

            # The author responds to each reviewer's review
            2: {
                'name': 'author_reviewer_discussion',
                'speaking_order': ["Author" for _ in reviewer_names],
            },

            3: {
                'name': 'reviewer_ac_discussion',
                'speaking_order': ["AC"] + reviewer_names,
            },

            4: {
                'name': 'ac_write_metareviews',
                'speaking_order': ["AC"]
            },
            5: {
                'name': 'ac_makes_decisions',
                'speaking_order': ["AC"]
            },
        }

        return self._phases

    @phases.setter
    def phases(self, value):
        self._phases = value

    def reset(self):
        self._current_phase = "review"
        self.phase_index = 0
        return super().reset()

    def load_message_history_from_cache(self):
        """Load the messages of Phases 1-3 (everything before the AC's second message) from the BASELINE
        experiment, then resume at Phase 4 (ac_write_metareviews)."""
        if self.phase_index == 0:

            print("Loading message history from BASELINE experiment")

            full_paper_discussion_path = get_rebuttal_dir(paper_id=self.paper_id,
                                                          experiment_name="BASELINE",
                                                          model_name=self.args.model_name,
                                                          conference=self.args.conference)

            with open(osp.join(full_paper_discussion_path, f"{self.paper_id}.json"), 'r', encoding='utf-8') as f:
                messages = json.load(f)['messages']

            num_messages_from_AC = 0

            for msg in messages:

                # We have already extracted contents from the paper.
                if msg['agent_name'] == "Paper Extractor":
                    continue

                # Encountering the 2nd message from the AC. Stop loading messages.
                if msg['agent_name'] == "AC" and num_messages_from_AC == 1:
                    break

                if msg['agent_name'] == "AC":
                    num_messages_from_AC += 1

                message = Message(**msg)
                self.message_pool.append_message(message)

            num_unique_reviewers = len(
                {msg['agent_name'] for msg in messages if msg['agent_name'].startswith("Reviewer")})

            assert num_unique_reviewers == self.args.num_reviewers_per_paper

            # Phases 1-3 have been restored from the cache; resume at Phase 4 (ac_write_metareviews).
            self.phase_index = 4

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agent wants to take
        """

        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        speaking_order = self.phases[self.phase_index]["speaking_order"]

        # Reached the end of the speaking order. Move to the next phase.
        logging.info(f"Phase {self.phase_index}: {self.phases[self._phase_index]['name']} "
                     f"| Player {self._next_player_index}: {speaking_order[self._next_player_index]}")

        terminal = self.is_terminal()

        if self._next_player_index == len(speaking_order) - 1:
            self._next_player_index = 0

            if self.phase_index == 4:
                terminal = True
                logger.info(
                    "Finished the simulation for Phases I-IV. Please run `python run_paper_decision_cli.py` for "
                    "Phase V (AC makes decisions).")

            else:
                print(f"Phase {self.phase_index}: end of the speaking order. Move to Phase {self.phase_index + 1}.")
                self.phase_index += 1
                self._current_turn += 1

        else:
            self._next_player_index += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=terminal,
        )  # Return all the messages

        return timestep

    def get_next_player(self) -> str:
        """Get the next player in the current phase."""
        speaking_order = self.phases[self.phase_index]["speaking_order"]   
        next_player = speaking_order[self._next_player_index]
        return next_player

    def get_observation(self, player_name=None) -> List[Message]:
        """Get observation for the player."""
        if player_name is None:
            return self.message_pool.get_all_messages()
        else:

            return self.message_pool.get_visible_messages_for_paper_review(
                player_name, phase_index=self.phase_index, next_player_idx=self._next_player_index,
                player_names=self.player_names
            )
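

# Illustrative sketch (hypothetical helper; it is not called by `PaperReview`).
# It replays the index bookkeeping of `PaperReview.step` on a plain phase table,
# so the turn order can be inspected without constructing the environment or
# calling any LLM. Assumptions: the run starts at phase 0 with player 0, and it
# ends after the last speaker of `last_phase` (4 by default, matching `step`),
# so Phase 5 ('ac_makes_decisions') is never visited. The `is_terminal()` check
# in `step` is ignored here.
def simulate_speaking_order(phases: dict, last_phase: int = 4) -> list:
    """Return the (phase_index, player_name) pairs in the order `step` visits them."""
    visited = []
    phase_index, next_player_index = 0, 0
    while True:
        speaking_order = phases[phase_index]["speaking_order"]
        visited.append((phase_index, speaking_order[next_player_index]))
        if next_player_index == len(speaking_order) - 1:
            # End of this phase's speaking order: reset to the first speaker.
            next_player_index = 0
            if phase_index == last_phase:
                break  # `step` returns a terminal TimeStep at this point
            phase_index += 1
        else:
            next_player_index += 1
    return visited


# For a phase table shaped like `PaperReview.phases` with two reviewers, the sketch
# visits 1 + 2 + 2 + 3 + 1 = 9 turns: the Paper Extractor, both reviewers, the author
# twice, the AC followed by both reviewers, and finally the AC in Phase 4.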