File size: 3,110 Bytes
ce00289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import List, Optional

import networkx as nx
import streamlit.components.v1 as components

from llm_transparency_tool.models.transparent_llm import ModelInfo
from llm_transparency_tool.server.graph_selection import GraphSelection, UiGraphNode

_RELEASE = True

if _RELEASE:
    parent_dir = os.path.dirname(os.path.abspath(__file__))
    config = {
        "path": os.path.join(parent_dir, "frontend/build"),
    }
else:
    config = {
        "url": "http://localhost:3001",
    }

_component_func = components.declare_component("contribution_graph", **config)


def is_node_valid(node: UiGraphNode, n_layers: int, n_tokens: int):
    return node.layer < n_layers and node.token < n_tokens


def is_selection_valid(s: GraphSelection, n_layers: int, n_tokens: int):
    if not s:
        return True
    if s.node:
        if not is_node_valid(s.node, n_layers, n_tokens):
            return False
    if s.edge:
        for node in [s.edge.source, s.edge.target]:
            if not is_node_valid(node, n_layers, n_tokens):
                return False
    return True


def contribution_graph(
    model_info: ModelInfo,
    tokens: List[str],
    graphs: List[nx.Graph],
    key: str,
) -> Optional[GraphSelection]:
    """Create a new instance of contribution graph.

    Returns selected graph node or None if nothing was selected.
    """
    assert len(tokens) == len(graphs)

    result = _component_func(
        component="graph",
        model_info=model_info.__dict__,
        tokens=tokens,
        edges_per_token=[nx.node_link_data(g)["links"] for g in graphs],
        default=None,
        key=key,
    )

    selection = GraphSelection.from_json(result)

    n_tokens = len(tokens)
    n_layers = model_info.n_layers
    # We need this extra protection because even though the component has to check for
    # the validity of the selection, sometimes it allows invalid output. It's some
    # unexpected effect that has something to do with React and how the output value is
    # set for the component.
    if not is_selection_valid(selection, n_layers, n_tokens):
        selection = None

    return selection


def selector(
    items: List[str],
    indices: List[int],
    temperatures: Optional[List[float]],
    preselected_index: Optional[int],
    key: str,
) -> Optional[int]:
    """Create a new instance of selector.

    Returns selected item index.
    """
    n = len(items)
    assert n == len(indices)
    items = [{"index": i, "text": s} for s, i in zip(items, indices)]

    if temperatures is not None:
        assert n == len(temperatures)
        for i, t in enumerate(temperatures):
            items[i]["temperature"] = t

    result = _component_func(
        component="selector",
        items=items,
        preselected_index=preselected_index,
        default=None,
        key=key,
    )

    return None if result is None else int(result)