File size: 2,609 Bytes
71fd9c5
7762f99
71fd9c5
7762f99
2b4b309
 
 
 
 
 
 
 
6fc91c7
2b4b309
bddcd7b
71fd9c5
 
fd936a6
2b4b309
 
 
 
 
 
 
 
 
fd936a6
 
 
2b4b309
 
 
 
 
 
fd936a6
a69bbb8
9b4773a
6fc91c7
 
 
 
 
 
 
5fca25d
 
6fc91c7
5fca25d
291ad35
 
 
 
 
 
 
 
 
 
 
6fc91c7
 
9ac3da0
 
6fc91c7
0d28c87
 
 
 
6fc91c7
40e000b
 
9ac3da0
 
 
0d28c87
 
9ac3da0
 
 
ff3c0c2
9ac3da0
ff3c0c2
 
7762f99
 
 
 
cea3391
 
 
 
 
7762f99
cea3391
 
7762f99
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
from typing import Union

import argilla as rg
import gradio as gr
from gradio.oauth import (
    OAUTH_CLIENT_ID,
    OAUTH_CLIENT_SECRET,
    OAUTH_SCOPES,
    OPENID_PROVIDER_URL,
    get_space,
)
from huggingface_hub import whoami

# Every Hugging Face token configured in the environment: HF_TOKEN first, then
# HF_TOKEN_1 ... HF_TOKEN_9, skipping variables that are unset or empty.
HF_TOKENS = [
    token
    for token in (
        os.getenv(env_name)
        for env_name in ["HF_TOKEN", *(f"HF_TOKEN_{i}" for i in range(1, 10))]
    )
    if token
]

# True when all four gradio OAuth settings are populated, or when get_space()
# returns None (presumably meaning the app is not running inside a Space).
_CHECK_IF_SPACE_IS_SET = all(
    (OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_SCOPES, OPENID_PROVIDER_URL)
) or get_space() is None

if not _CHECK_IF_SPACE_IS_SET:
    # Stand-in type so the `OAuthToken` annotations on the helpers below still
    # resolve when the real gradio class is not imported.
    OAuthToken = str
else:
    from gradio.oauth import OAuthToken


def get_login_button():
    """Create the large "Sign in with Hugging Face!" button and return it activated."""
    button = gr.LoginButton(
        value="Sign in with Hugging Face!",
        size="lg",
        scale=2,
    )
    return button.activate()


def get_duplicate_button():
    """Return a large ``gr.DuplicateButton`` when ``get_space()`` reports a Space.

    Returns ``None`` when ``get_space()`` returns ``None``.
    """
    if get_space() is None:
        return None
    return gr.DuplicateButton(size="lg")


def list_orgs(oauth_token: OAuthToken = None):
    """Return the namespaces offered to the signed-in user, own name first.

    Args:
        oauth_token: object exposing the raw token string as ``.token``;
            ``None`` when nobody is signed in.

    Returns:
        ``[]`` when ``oauth_token`` is ``None``. Otherwise a list whose first
        element is the user's own name, followed by the other candidate
        namespaces, de-duplicated and in their original order.
    """
    if oauth_token is None:
        return []
    data = whoami(oauth_token.token)
    user_name = data["name"]
    auth = data.get("auth") or {}
    # `or {}` also guards against an explicit null in the response.
    fine_grained = (auth.get("accessToken") or {}).get("fineGrained")

    if auth.get("type") == "oauth" or fine_grained is None:
        # OAuth sessions list the user's organisations directly. Tokens without a
        # fine-grained section (presumably classic read/write tokens — TODO
        # confirm) used to raise KeyError here; they now fall back to the
        # membership list, and the Hub remains the authority on write access.
        candidates = [org["name"] for org in data.get("orgs", [])]
    else:
        # Fine-grained tokens: keep only scoped entities granted "repo.write".
        # NOTE(review): scoped entities may include individual repos as well as
        # users/orgs; consider filtering on entry["entity"]["type"] — confirm.
        candidates = [
            entry["entity"]["name"]
            for entry in fine_grained.get("scoped", [])
            if "repo.write" in entry.get("permissions", [])
        ]

    # dict.fromkeys de-duplicates while preserving first-seen order; the user's
    # own name is always placed first exactly once.
    others = [name for name in dict.fromkeys(candidates) if name != user_name]
    return [user_name] + others


def get_org_dropdown(oauth_token: OAuthToken = None):
    """Build the "Organization" dropdown from ``list_orgs``.

    The first listed namespace is pre-selected (``None`` if the list is empty),
    and users may type a value that is not among the choices.
    """
    choices = list_orgs(oauth_token)
    default_choice = next(iter(choices), None)
    return gr.Dropdown(
        choices=choices,
        value=default_choice,
        label="Organization",
        allow_custom_value=True,
    )


def get_token(oauth_token: OAuthToken = None):
    """Return ``oauth_token.token``, or ``""`` when ``oauth_token`` is falsy."""
    return oauth_token.token if oauth_token else ""


def swap_visibilty(oauth_token: OAuthToken = None):
    """Return a gradio update that sets the main UI's CSS class by login state."""
    css_class = "main_ui_logged_in" if oauth_token else "main_ui_logged_out"
    return gr.update(elem_classes=[css_class])


def get_argilla_client() -> Union[rg.Argilla, None]:
    """Create an Argilla client from environment variables.

    The reviewer-specific pair ``ARGILLA_API_URL_SDG_REVIEWER`` /
    ``ARGILLA_API_KEY_SDG_REVIEWER`` is used only when both are set; otherwise
    ``ARGILLA_API_URL`` / ``ARGILLA_API_KEY`` are used instead.

    Returns:
        An ``rg.Argilla`` client, or ``None`` if constructing it raised. The
        failure is logged as a warning rather than silently discarded.
    """
    import logging

    api_url = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
    api_key = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
    if api_url is None or api_key is None:
        # Never mix halves of the two configurations: fall back as a pair.
        api_url = os.getenv("ARGILLA_API_URL")
        api_key = os.getenv("ARGILLA_API_KEY")
    try:
        return rg.Argilla(
            api_url=api_url,
            api_key=api_key,
        )
    except Exception:
        # Best-effort by design: callers treat None as "Argilla unavailable".
        # Log why, so misconfiguration is visible; the API key is deliberately
        # not logged.
        logging.getLogger(__name__).warning(
            "Could not create Argilla client for api_url=%r", api_url, exc_info=True
        )
        return None