Spaces:
Running
Running
File size: 4,720 Bytes
bdafe83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
"""
Datastore module for chat_arena.
This module provides utilities for storing messages and game results in a database.
Currently, it supports Supabase.
"""
import json
import os
import uuid
from typing import List
from .arena import Arena
from .message import Message
# Read the Supabase URL and secret key from environment variables.
# They are read unconditionally so both names are always defined, even when
# the optional ``supabase`` package cannot be imported.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
SUPABASE_SECRET_KEY = os.environ.get("SUPABASE_SECRET_KEY", "")

# Attempt importing Supabase. Supabase is an optional dependency: any failure
# while importing it simply marks the datastore as unavailable.
try:
    import supabase
except Exception:
    supabase_available = False
else:
    # Explicit check rather than ``assert``: assertions are removed under
    # ``python -O``, which would report Supabase as available even though
    # the credentials are missing.
    supabase_available = bool(SUPABASE_URL and SUPABASE_SECRET_KEY)
# Store the messages into the Supabase database
class SupabaseDB:
    """Write arena state to Supabase.

    Rows are inserted into four tables: ``Arena`` (environment config),
    ``Moderator`` (optional moderator config), ``Player`` (one row per
    player) and ``Message`` (one row per not-yet-logged message).

    Type annotations in this class are written as strings so they are not
    evaluated when the class is defined.
    """

    def __init__(self):
        """Create the Supabase client from the module-level credentials.

        Raises:
            AssertionError: if ``supabase_available`` is false or either
                ``SUPABASE_URL`` or ``SUPABASE_SECRET_KEY`` is empty.
        """
        # Explicit check instead of ``assert`` so the validation survives
        # ``python -O``. ``supabase_available`` is tested first so the
        # credential names are only read when the module setup succeeded.
        if not (supabase_available and SUPABASE_URL and SUPABASE_SECRET_KEY):
            raise AssertionError(
                "Supabase is not available: install the supabase package and "
                "set SUPABASE_URL and SUPABASE_SECRET_KEY"
            )
        self.client = supabase.create_client(SUPABASE_URL, SUPABASE_SECRET_KEY)

    # Save Arena state to Supabase
    def save_arena(self, arena: "Arena"):
        """Save the environment config, the player configs and the messages, in that order."""
        self._save_environment(arena)
        self._save_player_configs(arena)
        self.save_messages(arena)

    # Save the environment config of the arena
    def _save_environment(self, arena: "Arena"):
        """Insert the ``Arena`` row and, when a moderator is configured, its ``Moderator`` row."""
        env_config = arena.environment.to_config()
        # The moderator goes into its own table, so it is removed from the
        # environment config before that config is serialized.
        moderator_config = env_config.pop("moderator", None)
        arena_id = str(arena.uuid)

        arena_row = {
            "arena_id": arena_id,
            "global_prompt": arena.global_prompt,
            "env_type": env_config["env_type"],
            "env_config": json.dumps(env_config),
        }
        self.client.table("Arena").insert(arena_row).execute()

        if not moderator_config:
            return

        backend = moderator_config["backend"]
        moderator_row = {
            # Id derived from the arena uuid and the serialized config.
            "moderator_id": str(
                uuid.uuid5(arena.uuid, json.dumps(moderator_config))
            ),
            "arena_id": arena_id,
            "role_desc": moderator_config["role_desc"],
            "terminal_condition": moderator_config["terminal_condition"],
            "backend_type": backend["backend_type"],
            # ``.get`` mirrors the player rows: a backend config without
            # these keys stores None instead of raising KeyError.
            "temperature": backend.get("temperature"),
            "max_tokens": backend.get("max_tokens"),
        }
        self.client.table("Moderator").insert(moderator_row).execute()

    # Save the player configs of the arena
    def _save_player_configs(self, arena: "Arena"):
        """Insert one ``Player`` row per player in ``arena.players``."""
        arena_id = str(arena.uuid)
        player_rows = []
        for player in arena.players:
            player_config = player.to_config()
            backend = player_config["backend"]
            player_rows.append(
                {
                    "player_id": str(
                        uuid.uuid5(arena.uuid, json.dumps(player_config))
                    ),
                    "arena_id": arena_id,
                    "name": player.name,
                    "role_desc": player_config["role_desc"],
                    "backend_type": backend["backend_type"],
                    "temperature": backend.get("temperature"),
                    "max_tokens": backend.get("max_tokens"),
                }
            )

        # Nothing to store: skip the request instead of inserting an empty batch.
        if not player_rows:
            return
        self.client.table("Player").insert(player_rows).execute()

    # Save the messages
    def save_messages(self, arena: "Arena", messages: "List[Message]" = None):
        """Insert the not-yet-logged messages and mark them as logged.

        Args:
            arena: arena the messages belong to; its ``uuid`` is stored on
                every row and used to derive the message ids.
            messages: messages to store. When ``None``, they are taken from
                ``arena.environment.get_observation()``.
        """
        if messages is None:
            messages = arena.environment.get_observation()

        # Filter messages that are already logged
        pending = [msg for msg in messages if not msg.logged]
        # Nothing new to store: skip the request instead of inserting an empty batch.
        if not pending:
            return

        arena_id = str(arena.uuid)
        message_rows = [
            {
                "message_id": str(uuid.uuid5(arena.uuid, msg.msg_hash)),
                "arena_id": arena_id,
                "agent_name": msg.agent_name,
                "content": msg.content,
                "turn": msg.turn,
                "timestamp": str(msg.timestamp),
                "msg_type": msg.msg_type,
                "visible_to": json.dumps(msg.visible_to),
            }
            for msg in pending
        ]
        self.client.table("Message").insert(message_rows).execute()

        # Mark the messages as logged only after the insert call has returned.
        for msg in pending:
            msg.logged = True
# Log the arena results through the given database, if any
def log_arena(arena: Arena, database=None):
    """Call ``database.save_arena(arena)``; do nothing when ``database`` is ``None``."""
    if database is not None:
        database.save_arena(arena)
# Log the given messages through the given database, if any
def log_messages(arena: Arena, messages: List[Message], database=None):
    """Call ``database.save_messages(arena, messages)``; do nothing when ``database`` is ``None``."""
    if database is not None:
        database.save_messages(arena, messages)
|