|
import spaces |
|
import gradio as gr |
|
from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph |
|
from textwrap import dedent |
|
import rapidjson |
|
import spaces |
|
from pyvis.network import Network |
|
import networkx as nx |
|
import spacy |
|
from spacy import displacy |
|
from spacy.tokens import Span |
|
import random |
|
|
|
# Example graph payload using the same keys the visualisation helpers read:
# nodes carry "id", "type" and "detailed_type"; edges carry "from", "to" and
# "label", referring to nodes by id.
json_example = {
    "nodes": [
        {"id": "Aerosmith", "type": "organization", "detailed_type": "rock band"},
        {"id": "Steven Tyler", "type": "person", "detailed_type": "lead singer"},
        {
            "id": "vocal cord injury",
            "type": "medical condition",
            "detailed_type": "fractured larynx",
        },
        {"id": "retirement", "type": "event", "detailed_type": "announcement"},
        {"id": "touring", "type": "activity", "detailed_type": "musical performance"},
        {"id": "September 2023", "type": "date", "detailed_type": "specific time"},
    ],
    "edges": [
        {"from": "Aerosmith", "to": "Steven Tyler", "label": "led by"},
        {"from": "Steven Tyler", "to": "vocal cord injury", "label": "suffered"},
        {"from": "vocal cord injury", "to": "retirement", "label": "caused"},
        {"from": "retirement", "to": "touring", "label": "ended"},
        {"from": "vocal cord injury", "to": "September 2023", "label": "occurred in"},
    ],
}
|
|
|
@spaces.GPU
def extract(text, model):
    """Run graph extraction on *text* and return the decoded JSON result.

    Args:
        text: Input passage handed to the extractor.
        model: Value forwarded to ``Phi3InstructGraph`` as its ``model``
            argument.

    Returns:
        Whatever ``rapidjson.loads`` produces from the extractor's output.
    """
    extractor = Phi3InstructGraph(model=model)
    raw_output = extractor.extract(text)
    return rapidjson.loads(raw_output)
|
|
|
def handle_text(text):
    """Collapse each run of whitespace in *text* into a single space.

    Splitting on whitespace also drops anything leading or trailing, so the
    result never starts or ends with a space.
    """
    words = text.split()
    return " ".join(words)
|
|
|
def get_random_color():
    """Return a random colour as a lowercase ``#rrggbb`` hex string."""
    return f"#{random.randint(0, 0xFFFFFF):06x}"


def get_random_light_color():
    """Return a random light colour as a lowercase ``#rrggbb`` hex string.

    Every channel is drawn from 128-255, so no channel falls in the darker
    half of its range.
    """
    # Channels are drawn in r, g, b order, one randint call each.
    r, g, b = (random.randint(128, 255) for _ in range(3))
    return f"#{r:02x}{g:02x}{b:02x}"
|
|
|
def find_token_indices(doc, substring, text):
    """Locate every token-aligned occurrence of *substring* in *text*.

    Args:
        doc: Iterable of tokens for *text*. Each token must expose ``idx``
            (character offset of the token), ``i`` (token index) and support
            ``len()``.
        substring: Text to search for. An empty string yields no matches.
        text: The string searched for *substring*; token offsets in *doc*
            are compared against positions in this string.

    Returns:
        A list of ``{"start": int, "end": int}`` dicts of token indices
        (``end`` is one past the last token), one per non-overlapping
        occurrence whose first and last characters both fall on token
        boundaries. Occurrences that do not align are reported on stdout
        and skipped.
    """
    # An empty needle matches at offset 0 without ever advancing the search
    # position, so the loop below would never terminate.
    if not substring:
        return []

    # Index the token boundaries once instead of rescanning the whole
    # document for every occurrence.
    token_starting_at = {}
    token_ending_at = {}
    for token in doc:
        token_starting_at[token.idx] = token.i
        token_ending_at[token.idx + len(token)] = token.i + 1

    result = []
    start_index = text.find(substring)
    while start_index != -1:
        end_index = start_index + len(substring)
        start_token = token_starting_at.get(start_index)
        end_token = token_ending_at.get(end_index)

        if start_token is None or end_token is None:
            print(f"Token boundaries not found for '{substring}' at index {start_index}")
        else:
            result.append({"start": start_token, "end": end_token})

        # Resume after this occurrence so matches never overlap.
        start_index = text.find(substring, end_index)

    if not result:
        print(f"Token boundaries not found for '{substring}'")

    return result
|
|
|
|
|
def create_custom_entity_viz(data, full_text):
    """Render the graph's nodes as labelled highlights over *full_text*.

    Args:
        data: Graph dict. Each entry of ``data["nodes"]`` must provide
            ``"id"`` (the text searched for in *full_text*) and ``"type"``
            (used as the span label and as the colour key).
        full_text: The text the nodes are located in.

    Returns:
        The markup returned by ``displacy.render`` with ``style="span"``.
    """
    # Function-scope import: spaCy is already a dependency of this module.
    from spacy.util import filter_spans

    nlp = spacy.blank("xx")
    doc = nlp(full_text)

    spans = []
    colors = {}
    for node in data["nodes"]:
        entity_spans = find_token_indices(doc, node["id"], full_text)
        for dataentity in entity_spans:
            start = dataentity["start"]
            end = dataentity["end"]

            print("entity spans:", entity_spans)
            if start < len(doc) and end <= len(doc):
                span = Span(doc, start, end, label=node["type"])
                spans.append(span)
                # One colour per node type, chosen the first time it is seen.
                if node["type"] not in colors:
                    colors[node["type"]] = get_random_light_color()

    for span in spans:
        print(f"Span: {span.text}, Label: {span.label_}")

    # Doc.set_ents refuses spans that share a token, which happens whenever
    # one node id is contained in another or a node is listed twice. Only the
    # entity annotation is filtered; the span group keeps every span.
    doc.set_ents(filter_spans(spans), default="unmodified")
    doc.spans["sc"] = spans

    # NOTE(review): "ents", "style" and "manual" are passed through unchanged
    # from the original options; confirm the span renderer actually reads them.
    options = {
        "colors": colors,
        "ents": list(colors.keys()),
        "style": "ent",
        "manual": True
    }

    html = displacy.render(doc, style="span", options=options)
    return html
|
|
|
|
|
def create_graph(json_data):
    """Build an interactive graph view of *json_data* wrapped in an iframe.

    Args:
        json_data: Graph dict. Each entry of ``json_data['nodes']`` must have
            ``'id'``, ``'type'`` and ``'detailed_type'``; each entry of
            ``json_data['edges']`` must have ``'from'``, ``'to'`` and
            ``'label'``.

    Returns:
        An ``<iframe>`` element as a string whose ``srcdoc`` attribute holds
        the page generated by pyvis.

    Raises:
        KeyError: If a node or edge lacks one of the keys listed above.
    """
    # Function-scope import keeps the block self-contained.
    from html import escape

    # A directed graph keeps each edge's from/to order; an undirected
    # nx.Graph may hand the endpoints back swapped, which flips the arrow
    # drawn by the directed pyvis network below.
    graph = nx.DiGraph()

    for node in json_data['nodes']:
        graph.add_node(node['id'], title=f"{node['type']}: {node['detailed_type']}")

    for edge in json_data['edges']:
        graph.add_edge(edge['from'], edge['to'], title=edge['label'], label=edge['label'])

    nt = Network(
        width="720px",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#FFFFFF",
        font_color="#111827"
    )
    nt.from_nx(graph)
    nt.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )

    page = nt.generate_html()

    # Escape the document for use inside the quoted srcdoc attribute. Swapping
    # quote characters instead would corrupt any label or script string that
    # contains an apostrophe.
    srcdoc = escape(page, quote=True)

    return (
        '<iframe style="width: 140%; height: 620px; margin: 0 auto;" name="result" '
        'allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;" '
        'sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups '
        'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" '
        f"""allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
    )
|
|
|
|
|
def process_and_visualize(text, model):
    """Extract a graph from *text* and build both visualisations.

    Args:
        text: Input passage; must be non-empty.
        model: Model selection passed on to ``extract``; must be truthy.

    Returns:
        A 3-tuple ``(graph_html, entities_viz, json_data)``: the graph
        markup from ``create_graph``, the entity markup from
        ``create_custom_entity_viz`` and the extracted graph data.

    Raises:
        gr.Error: If *text* or *model* is missing.
    """
    if not (text and model):
        raise gr.Error("Text and model must be provided.")

    graph_json = extract(text, model)
    print(graph_json)

    # The entity view is built before the graph view, as in the original flow.
    entity_markup = create_custom_entity_viz(graph_json, text)
    graph_markup = create_graph(graph_json)

    return graph_markup, entity_markup, graph_json
|
|
|
|
|
|
|
# Gradio UI: model picker, text input and examples on the left; the entity
# and graph visualisations on the right.
with gr.Blocks(title="Phi-3 Mini 4k Instruct Graph (by Emergent Methods)") as demo:
    gr.Markdown("# Phi-3 Mini 4k Instruct Graph (by Emergent Methods)")
    gr.Markdown("Extract a JSON graph from a text input and visualize it.")

    with gr.Row():
        with gr.Column(scale=1):
            input_model = gr.Dropdown(
                MODEL_LIST, label="Model",
            )
            input_text = gr.TextArea(label="Text", info="The text to be extracted")

            # Example passages are whitespace-normalised so the line breaks
            # used to wrap the literals do not end up in the text box.
            examples = gr.Examples(
                examples=[
                    handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
                    lead singer Steven Tyler's unrecoverable vocal cord injury.
                    The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
                    which he suffered in September 2023."""),
                    handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
                    court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
                    in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
                    pleaded not guilty to the charges."""),
                ],
                inputs=input_text
            )

            submit_button = gr.Button("Extract and Visualize")

        with gr.Column(scale=1):
            output_entity_viz = gr.HTML(label="Entities Visualization", show_label=True)
            output_graph = gr.HTML(label="Graph Visualization", show_label=True)

    # NOTE(review): process_and_visualize returns three values (graph markup,
    # entity markup, raw graph data) but only two outputs are wired here;
    # confirm whether a JSON output component was intended for the third.
    submit_button.click(
        fn=process_and_visualize,
        inputs=[input_text, input_model],
        outputs=[output_graph, output_entity_viz]
    )

demo.launch(share=False)