|
import logging |
|
import os |
|
from datetime import datetime |
|
from decimal import Decimal |
|
from typing import List |
|
|
|
import boto3 |
|
from boto3.dynamodb.conditions import Attr, Key |
|
from datasets import Dataset |
|
|
|
# Root logger level is taken from the LOG_LEVEL environment variable,
# falling back to "INFO" when it is not set.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
)

# Module-wide DynamoDB resource handle (region us-east-1) used by the table
# helpers defined below.
dynamodb = boto3.resource(
    'dynamodb',
    region_name='us-east-1',
)
|
|
|
|
|
def _create_arena_table():
    """Create the ``oaaic_chatbot_arena`` table and its ``TimestampIndex``.

    The table uses ``arena_battle_id`` (string) as its hash key. The global
    secondary index ``TimestampIndex`` uses the same hash key with
    ``timestamp`` (string) as range key and projects all attributes. Table
    and index are both provisioned with 5 read / 5 write capacity units.
    """
    battle_id_hash_key = {'AttributeName': 'arena_battle_id', 'KeyType': 'HASH'}
    timestamp_range_key = {'AttributeName': 'timestamp', 'KeyType': 'RANGE'}

    timestamp_index = {
        'IndexName': 'TimestampIndex',
        'KeySchema': [dict(battle_id_hash_key), timestamp_range_key],
        'Projection': {'ProjectionType': 'ALL'},
        'ProvisionedThroughput': {
            'ReadCapacityUnits': 5,
            'WriteCapacityUnits': 5,
        },
    }

    dynamodb.create_table(
        TableName='oaaic_chatbot_arena',
        KeySchema=[battle_id_hash_key],
        AttributeDefinitions=[
            {'AttributeName': 'arena_battle_id', 'AttributeType': 'S'},
            {'AttributeName': 'timestamp', 'AttributeType': 'S'},
        ],
        ProvisionedThroughput={
            'ReadCapacityUnits': 5,
            'WriteCapacityUnits': 5,
        },
        GlobalSecondaryIndexes=[timestamp_index],
    )
|
|
|
def _create_elo_scores_table():
    """Create the ``elo_scores`` table.

    The table is keyed only by ``chatbot_name`` (string hash key) and is
    provisioned with 5 read / 5 write capacity units.
    """
    table_spec = {
        'TableName': 'elo_scores',
        'KeySchema': [
            {'AttributeName': 'chatbot_name', 'KeyType': 'HASH'},
        ],
        'AttributeDefinitions': [
            {'AttributeName': 'chatbot_name', 'AttributeType': 'S'},
        ],
        'ProvisionedThroughput': {
            'ReadCapacityUnits': 5,
            'WriteCapacityUnits': 5,
        },
    }
    dynamodb.create_table(**table_spec)
|
|
|
|
|
def _create_elo_logs_table():
    """Create the ``elo_logs`` table and its ``AllTimestampIndex``.

    Primary key: ``arena_battle_id`` (hash) + ``battle_timestamp`` (range),
    both strings. The global secondary index ``AllTimestampIndex`` is keyed
    by the string attribute ``all`` (hash) + ``battle_timestamp`` (range)
    and projects all attributes. Table and index are both provisioned with
    10 read / 10 write capacity units.
    """
    string_attributes = ('arena_battle_id', 'battle_timestamp', 'all')
    attribute_definitions = [
        {'AttributeName': attribute, 'AttributeType': 'S'}
        for attribute in string_attributes
    ]

    all_timestamp_index = {
        'IndexName': 'AllTimestampIndex',
        'KeySchema': [
            {'AttributeName': 'all', 'KeyType': 'HASH'},
            {'AttributeName': 'battle_timestamp', 'KeyType': 'RANGE'},
        ],
        'Projection': {'ProjectionType': 'ALL'},
        'ProvisionedThroughput': {
            'ReadCapacityUnits': 10,
            'WriteCapacityUnits': 10,
        },
    }

    dynamodb.create_table(
        TableName='elo_logs',
        KeySchema=[
            {'AttributeName': 'arena_battle_id', 'KeyType': 'HASH'},
            {'AttributeName': 'battle_timestamp', 'KeyType': 'RANGE'},
        ],
        AttributeDefinitions=attribute_definitions,
        ProvisionedThroughput={
            'ReadCapacityUnits': 10,
            'WriteCapacityUnits': 10,
        },
        GlobalSecondaryIndexes=[all_timestamp_index],
    )
|
|
|
|
|
def get_unprocessed_battles(last_processed_timestamp):
    """Return every arena battle whose ``timestamp`` is after the given one.

    Scans the ``oaaic_chatbot_arena`` table with a filter on
    ``timestamp > last_processed_timestamp``.

    Args:
        last_processed_timestamp: Lower bound (exclusive) compared against
            each item's ``timestamp`` attribute.

    Returns:
        A list of the matching items, in the order the scan returned them
        (not sorted).
    """
    table = dynamodb.Table('oaaic_chatbot_arena')

    scan_kwargs = {
        'FilterExpression': Attr('timestamp').gt(last_processed_timestamp),
    }

    battles = []
    while True:
        response = table.scan(**scan_kwargs)
        battles.extend(response.get('Items', []))

        # A scan call returns at most one page, and the filter is applied
        # after the page is read, so a page may even be empty while more
        # data remains. Keep going until DynamoDB stops handing back a
        # continuation key, otherwise later battles are silently dropped.
        last_evaluated_key = response.get('LastEvaluatedKey')
        if not last_evaluated_key:
            return battles
        scan_kwargs['ExclusiveStartKey'] = last_evaluated_key
|
|
|
|
|
def calculate_elo(rating1, rating2, result, K=32):
    """Compute both players' updated Elo ratings after one game.

    Args:
        rating1: Current rating of the first player (converted with ``float``).
        rating2: Current rating of the second player (converted with ``float``).
        result: Score of the first player; the second player's score is
            taken as ``1.0 - result``.
        K: Multiplier applied to the difference between actual and expected
            score.

    Returns:
        A ``(new_rating1, new_rating2)`` tuple of ``Decimal`` values
        quantized to two decimal places.
    """
    first = float(rating1)
    second = float(rating2)

    # Expected score of the first player from the logistic Elo curve; the
    # second player's expectation is its complement.
    expected_first = 1.0 / (1.0 + 10.0 ** ((second - first) / 400.0))
    expected_second = 1.0 - expected_first

    updated_ratings = (
        first + K * (result - expected_first),
        second + K * ((1.0 - result) - expected_second),
    )

    two_places = Decimal('0.00')
    return tuple(Decimal(value).quantize(two_places) for value in updated_ratings)
|
|
|
|
|
def get_last_processed_timestamp():
    """Return the ``battle_timestamp`` of the newest entry in ``elo_logs``.

    Queries the ``AllTimestampIndex`` index for the partition ``all == 'ALL'``
    in descending range-key order, limited to one item.

    Returns:
        The ``battle_timestamp`` of that item, or ``'1970-01-01T00:00:00'``
        when the query returns no items.
    """
    logs_table = dynamodb.Table('elo_logs')

    query_kwargs = {
        'IndexName': 'AllTimestampIndex',
        'KeyConditionExpression': Key('all').eq('ALL'),
        'ScanIndexForward': False,  # descending by battle_timestamp
        'Limit': 1,
    }
    newest_entries = logs_table.query(**query_kwargs)['Items']

    if newest_entries:
        return newest_entries[0]['battle_timestamp']
    return '1970-01-01T00:00:00'
|
|
|
|
|
def log_elo_update(arena_battle_id, battle_timestamp, new_rating1, new_rating2):
    """Write one processed-battle record to the ``elo_logs`` table.

    Args:
        arena_battle_id: Stored as the item's ``arena_battle_id``.
        battle_timestamp: Stored as the item's ``battle_timestamp``.
        new_rating1: Stored as ``new_rating1``.
        new_rating2: Stored as ``new_rating2``.

    The item also records ``log_timestamp`` (``datetime.now()`` in ISO
    format) and the constant ``all = 'ALL'``.
    """
    logs_table = dynamodb.Table('elo_logs')

    log_entry = {
        'arena_battle_id': arena_battle_id,
        'battle_timestamp': battle_timestamp,
        'log_timestamp': datetime.now().isoformat(),
        'new_rating1': new_rating1,
        'new_rating2': new_rating2,
        'all': 'ALL',
    }
    logs_table.put_item(Item=log_entry)
|
|
|
|
|
def get_elo_score(chatbot_name, elo_scores):
    """Look up a chatbot's Elo score, preferring the in-memory mapping.

    Args:
        chatbot_name: Name used both as the mapping key and as the
            ``chatbot_name`` key in the ``elo_scores`` table.
        elo_scores: Mapping of already-known scores, checked first.

    Returns:
        The value from ``elo_scores`` if present; otherwise the stored
        ``elo_score`` from the table, or ``1500`` when the table has no item
        for this chatbot.
    """
    if chatbot_name in elo_scores:
        return elo_scores[chatbot_name]

    stored = dynamodb.Table('elo_scores').get_item(
        Key={'chatbot_name': chatbot_name},
    )

    if 'Item' in stored:
        return stored['Item']['elo_score']
    return 1500
|
|
|
|
|
def update_elo_score(chatbot_name, new_elo_score):
    """Store a chatbot's Elo score in the ``elo_scores`` table.

    Args:
        chatbot_name: Value of the item's ``chatbot_name`` key.
        new_elo_score: Score to store; it is converted through ``str`` into
            a ``Decimal`` before being written as ``elo_score``.
    """
    score_item = {
        'chatbot_name': chatbot_name,
        'elo_score': Decimal(str(new_elo_score)),
    }
    dynamodb.Table('elo_scores').put_item(Item=score_item)
|
|
|
|
|
def get_elo_scores():
    """Return every item stored in the ``elo_scores`` table.

    Returns:
        A list of the items returned by scanning the whole table.
    """
    table = dynamodb.Table('elo_scores')

    scan_kwargs = {}
    data = []
    while True:
        response = table.scan(**scan_kwargs)
        data.extend(response.get('Items', []))

        # A scan call returns only one page; follow the continuation key so
        # that no scores are left out once the table outgrows a single page.
        last_evaluated_key = response.get('LastEvaluatedKey')
        if not last_evaluated_key:
            return data
        scan_kwargs['ExclusiveStartKey'] = last_evaluated_key
|
|
|
|
|
def _backfill_logs():
    """Set ``all = 'ALL'`` on every item in the ``elo_logs`` table.

    Each item is addressed by its ``arena_battle_id`` / ``battle_timestamp``
    key pair and updated individually.
    """
    table = dynamodb.Table('elo_logs')

    scan_kwargs = {}
    while True:
        response = table.scan(**scan_kwargs)

        for item in response.get('Items', []):
            table.update_item(
                Key={
                    'arena_battle_id': item['arena_battle_id'],
                    'battle_timestamp': item['battle_timestamp'],
                },
                # The attribute name is passed through a '#all' placeholder
                # rather than written directly in the expression.
                UpdateExpression="SET #all = :value",
                ExpressionAttributeNames={'#all': 'all'},
                ExpressionAttributeValues={':value': 'ALL'},
            )

        # A scan call returns only one page; without following the
        # continuation key, rows past the first page would never be
        # backfilled.
        last_evaluated_key = response.get('LastEvaluatedKey')
        if not last_evaluated_key:
            break
        scan_kwargs['ExclusiveStartKey'] = last_evaluated_key
|
|
|
def main():
    """Process new arena battles, update Elo scores, and publish them.

    Steps:
      1. Fetch the battles newer than the last logged battle timestamp and
         sort them by ``timestamp``.
      2. For each battle with a scorable label, compute the two chatbots'
         new ratings, log the update, and store both scores.
      3. Re-read all stored scores, convert ``elo_score`` to ``float``, and
         print them.
      4. If any battles were fetched, push the scores to the
         ``openaccess-ai-collective/chatbot-arena-elo-scores`` dataset.
    """
    # Battle label -> Elo result from choice1's side as passed to
    # calculate_elo: 1 scores a win for choice1, 2 scores a loss, and both
    # 0 and -1 are scored as a draw. Any other label is not scored.
    label_to_result = {-1: 0.5, 0: 0.5, 1: 1, 2: 0}

    last_processed_timestamp = get_last_processed_timestamp()
    battles: List[dict] = get_unprocessed_battles(last_processed_timestamp)
    battles = sorted(battles, key=lambda x: x['timestamp'])
    elo_scores = {}

    for battle in battles:
        print(repr(battle))

        # A battle without a label is skipped like any other unscorable
        # label instead of aborting the whole run with a KeyError.
        elo_result = label_to_result.get(battle.get('label'))
        if elo_result is None:  # 0 is a valid result, so test against None
            continue

        name1 = battle['choice1_name']
        name2 = battle['choice2_name']

        # Load each chatbot's current score once; later battles in this run
        # reuse the in-memory value.
        for chatbot_name in (name1, name2):
            if chatbot_name not in elo_scores:
                elo_scores[chatbot_name] = get_elo_score(chatbot_name, elo_scores)

        new_rating1, new_rating2 = calculate_elo(elo_scores[name1], elo_scores[name2], elo_result)

        # Log before overwriting the cache so the message shows old -> new.
        logging.info(f"{name1}: {elo_scores[name1]} -> {new_rating1} | {name2}: {elo_scores[name2]} -> {new_rating2}")

        elo_scores[name1] = new_rating1
        elo_scores[name2] = new_rating2

        log_elo_update(battle['arena_battle_id'], battle['timestamp'], new_rating1, new_rating2)
        update_elo_score(name1, new_rating1)
        update_elo_score(name2, new_rating2)

    # Re-read the stored scores and convert elo_score to float before
    # printing and publishing.
    published_scores = [
        {**row, "elo_score": float(row["elo_score"])}
        for row in get_elo_scores()
    ]
    print(published_scores)

    # Only publish when this run fetched at least one battle.
    if battles:
        elo_dataset = Dataset.from_list(published_scores)
        elo_dataset.push_to_hub("openaccess-ai-collective/chatbot-arena-elo-scores", private=False)
|
|
|
|
|
# Run the Elo update only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|