File size: 1,817 Bytes
460d762 d52179b 460d762 d52179b 460d762 d52179b 460d762 d52179b 460d762 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import re
from typing import List
from src.utils_display import AutoEvalColumn
from src.auto_leaderboard.model_metadata_type import get_model_type
from huggingface_hub import HfApi
import huggingface_hub
# Shared Hugging Face Hub client; get_model_infos_from_hub uses it for every
# per-model metadata lookup (api.model_info).
api = HfApi()
def get_model_infos_from_hub(leaderboard_data: List[dict]):
    """Enrich each leaderboard row in place with license, likes and parameter count.

    For every row, the repository named by ``model_data["model_name_for_query"]``
    is looked up on the Hugging Face Hub.  Rows whose repository lookup raises
    ``RepositoryNotFoundError`` get ``None`` for all three columns.

    Args:
        leaderboard_data: Leaderboard rows; each must contain the key
            ``"model_name_for_query"``.  Rows are mutated in place.
    """
    for model_data in leaderboard_data:
        model_name = model_data["model_name_for_query"]
        try:
            model_info = api.model_info(model_name)
        # Use the public re-export rather than the private ``utils._errors``
        # module path, which is not a stable import location.
        except huggingface_hub.utils.RepositoryNotFoundError:
            model_data[AutoEvalColumn.license.name] = None
            model_data[AutoEvalColumn.likes.name] = None
            model_data[AutoEvalColumn.params.name] = None
            continue
        model_data[AutoEvalColumn.license.name] = get_model_license(model_info)
        model_data[AutoEvalColumn.likes.name] = get_model_likes(model_info)
        model_data[AutoEvalColumn.params.name] = get_model_size(model_name, model_info)
def get_model_license(model_info):
    """Return the license declared in the model's card metadata, or ``None``.

    Reading the license is best-effort: any failure (no card data, no
    ``"license"`` entry, unexpected card type, ...) is reported as ``None``.
    """
    try:
        card_data = model_info.cardData
        return card_data["license"]
    except Exception:
        return None
def get_model_likes(model_info):
    """Return the like count stored on the Hub model metadata object."""
    like_count = model_info.likes
    return like_count
# Matches a parameter-count tag such as "7b", "1.3b" or "560m" in a model name.
# Group 1 is the unit: "b" = billions, "m" = millions.
size_pattern = re.compile(r"\d+(?:\.\d+)?(b|m)")


def get_model_size(model_name, model_info):
    """Return the model's parameter count in billions, or ``None`` if unknown.

    The exact count from the repository's safetensors metadata is preferred.
    When that is unavailable, the size is parsed from a tag in ``model_name``
    (e.g. ``"llama-2-7b"`` -> 7.0, ``"opt-1.3b"`` -> 1.3, ``"bloom-560m"`` -> 0.56).

    Args:
        model_name: Hub repository id, used only for the name-based fallback.
        model_info: Hub model metadata object; only ``model_info.safetensors``
            is read.

    Returns:
        Size in billions of parameters rounded to 3 decimals, or ``None`` when
        neither the metadata nor the name yields a size.
    """
    try:
        return round(model_info.safetensors["total"] / 1e9, 3)
    except (AttributeError, KeyError, TypeError):
        # No usable safetensors metadata: the attribute is missing, is None,
        # or has no "total" entry.  Fall back to parsing the model name.
        pass

    if not isinstance(model_name, str):
        return None
    size_match = size_pattern.search(model_name.lower())
    if size_match is None:
        return None
    # Strip the trailing unit letter; float() accepts both "7" and "1.3".
    amount = float(size_match.group(0)[:-1])
    unit = size_match.group(1)
    return round(amount if unit == "b" else amount / 1e3, 3)
def apply_metadata(leaderboard_data: List[dict]):
    """Run every metadata enrichment step over the leaderboard rows.

    Steps run in a fixed order: ``get_model_type`` first, then
    ``get_model_infos_from_hub``.  Each receives the same list and, as far as
    is visible here, is expected to update the rows in place; this function
    returns nothing.
    """
    for enrich_step in (get_model_type, get_model_infos_from_hub):
        enrich_step(leaderboard_data)
|