|
from datasets import load_dataset |
|
import gradio as gr |
|
import pandas as pd |
|
import re |
|
import matplotlib.pyplot as plt |
|
import base64 |
|
import random |
|
from pyvis.network import Network |
|
|
|
|
|
|
|
# Load the dataset by its Hugging Face Hub identifier.
ds = load_dataset("nicpopovic/vital_articles_synthetic_information_extraction")
# Convert the 'train' split to a pandas DataFrame; the view functions in this
# file look rows up positionally via ``df.iloc``.
df = ds['train'].to_pandas()
|
|
|
|
|
|
|
def get_palette():
    """
    Build the list of highlight colors from matplotlib's 'Pastel1' colormap.

    Returns:
        list: One '#rrggbb' hex string per colormap entry, in colormap order.
    """
    hex_codes = []
    for red, green, blue in plt.get_cmap('Pastel1').colors:
        # Channels are scaled by 255 and truncated with int() (not rounded)
        # before being formatted as two hex digits each.
        channels = (int(red * 255), int(green * 255), int(blue * 255))
        hex_codes.append('#' + ''.join(f'{channel:02x}' for channel in channels))
    return hex_codes
|
|
|
def assign_type_colors(entity_types):
    """
    Map each entity type to a palette color.

    Types are visited in sorted order and colors are handed out by position,
    wrapping around when there are more types than palette entries.

    Args:
        entity_types (iterable): Entity type strings.
    Returns:
        dict: Mapping from entity type to color hex code.
    """
    colors = get_palette()
    mapping = {}
    for position, type_name in enumerate(sorted(entity_types)):
        # Modulo makes the palette cycle instead of running out of colors.
        mapping[type_name] = colors[position % len(colors)]
    return mapping
|
|
|
def ent_to_span(text, type_colors):
    """
    Convert <ent> tags in text to styled HTML spans for highlighting.

    Each ``<ent id="..." type="...">...</ent>`` occurrence becomes a colored
    ``<span>`` carrying the id and type as ``data-`` attributes, followed by a
    small "id: type" label. Text outside of ``<ent>`` tags is returned
    unchanged.

    Args:
        text (str): Annotated text with <ent> tags.
        type_colors (dict): Mapping from entity type to color. Types missing
            from the mapping fall back to a translucent grey.
    Returns:
        str: HTML string with highlighted entities.
    """
    # Function-scope import so the block does not depend on a file-level
    # import of ``html``.
    from html import escape

    def repl(match):
        ent_id, ent_type, ent_text = match.group(1), match.group(2), match.group(3)
        # The color lookup uses the raw (unescaped) type, exactly as it
        # appears in the annotation.
        color = type_colors.get(ent_type, "#cccccc88")
        # id/type end up inside single-quoted attributes and in the label
        # text, so quotes and markup characters must be escaped or they would
        # terminate the attribute / inject tags.
        safe_id = escape(ent_id, quote=True)
        safe_type = escape(ent_type, quote=True)
        # The entity text itself is emitted verbatim, consistent with the
        # surrounding document text, which is also passed through untouched.
        return (
            f"<span class='highlight spanhighlight' "
            f"data-id='{safe_id}' data-type='{safe_type}' "
            f"style='background-color: {color}'>"
            f"{ent_text}<span class='small-text'>{safe_id}: {safe_type}</span></span>"
        )

    pattern = r'<ent id="(.*?)" type="(.*?)">(.*?)</ent>'
    return re.sub(pattern, repl, text)
|
|
|
def vis_network_html(relations, entities, type_colors):
    """
    Generate an interactive network graph as an HTML iframe.

    Only entities that take part in at least one relation become nodes; every
    relation becomes a directed, labelled edge.

    Args:
        relations (list): Relation dicts with "subject", "object" and
            "predicate" keys and an optional "description".
        entities (list): Entity dicts with "id" and "name" keys and an
            optional "type".
        type_colors (dict): Mapping from entity type to color. Types missing
            from the mapping fall back to a translucent grey.
    Returns:
        str: HTML iframe with the network graph.
    """
    net = Network(height="350px", width="100%", directed=True, notebook=False)

    # Gather the ids referenced by any relation once, so the per-entity check
    # below is a set lookup instead of a rescan of all relations.
    connected_ids = set()
    for rel in relations:
        connected_ids.add(rel["subject"])
        connected_ids.add(rel["object"])

    for ent in entities:
        # Skip isolated entities; they would only clutter the graph.
        if ent["id"] not in connected_ids:
            continue
        color = type_colors.get(ent.get("type", ""), "#cccccc88")
        net.add_node(
            ent["id"],
            label=ent["name"],
            shape="box",
            color=color,
        )

    for rel in relations:
        net.add_edge(
            rel["subject"],
            rel["object"],
            label=rel["predicate"],
            arrows="to",
            title=rel.get("description", ""),
        )

    # Embed the generated page as a base64 data URI so the whole graph travels
    # inside a single iframe tag.
    html = net.generate_html()
    html_b64 = base64.b64encode(html.encode("utf-8")).decode("utf-8")
    return f'<iframe src="data:text/html;base64,{html_b64}" width="100%" height="350px" frameborder="0"></iframe>'
|
|
|
|
|
|
|
def show_annotated(index):
    """
    Render the annotated text of one dataset row as highlighted HTML.

    Args:
        index (int): Row index in the DataFrame (coerced with ``int``).
    Returns:
        str: HTML string with highlighted entities, wrapped in a styled div.
    """
    record = df.iloc[int(index)]
    # Colors are derived from the row's own 'entity_types' column.
    colors_by_type = assign_type_colors(record['entity_types'])
    highlighted = ent_to_span(record['annotated_text'], colors_by_type)
    return f'<div style="font-size:1.1em;line-height:1.6">{highlighted}</div>'
|
|
|
def show_graph(index):
    """
    Render the relation graph of one dataset row as an HTML iframe.

    Args:
        index (int): Row index in the DataFrame (coerced with ``int``).
    Returns:
        str: HTML iframe with the network graph.
    """
    record = df.iloc[int(index)]
    row_entities = record['entities']
    row_relations = record['relations']

    # Colors here come from the types actually present on the entities;
    # entities without a "type" key are grouped under the empty string.
    distinct_types = set()
    for ent in row_entities:
        distinct_types.add(ent.get("type", ""))
    colors_by_type = assign_type_colors(sorted(distinct_types))

    return vis_network_html(row_relations, row_entities, colors_by_type)
|
|
|
def get_raw_json(index):
    """
    Fetch one dataset row as a plain dictionary.

    Args:
        index (int): Row index in the DataFrame (coerced with ``int``).
    Returns:
        dict: Row data keyed by column name.
    """
    position = int(index)
    record = df.iloc[position]
    return record.to_dict()
|
|
|
|
|
|
|
|
|
# Custom styles handed to ``gr.Blocks(css=...)`` below. The ``highlight``,
# ``spanhighlight`` and ``small-text`` classes are the ones emitted by
# ``ent_to_span``; the remaining rules hide the footer and constrain the
# container.
css = """
<style>
.prose { line-height: 200%; }
.highlight { display: inline; }
.highlight::after { background-color: var(data-color); }
.spanhighlight { padding: 2px 5px; border-radius: 5px; }
/* Removed .tooltip and .tooltip::after */
.small-text {
    padding: 2px 5px;
    background-color: white;
    border-radius: 5px;
    font-size: xx-small;
    margin-left: 0.5em;
    vertical-align: 0.2em;
    font-weight: bold;
    color: grey!important;
}
footer { display:none !important; }
.gradio-container { padding: 0!important; height:400px; }
</style>
"""
|
|
|
|
|
with gr.Blocks(css=css, fill_height=True, fill_width=True) as demo:
    # Row selector: an integer field bounded to the valid positions in df.
    idx = gr.Number(
        label="Row Index",
        value=0,
        precision=0,
        minimum=0,
        maximum=len(df) - 1,
    )

    with gr.Tabs():
        with gr.Tab("Annotated View"):
            # The graph is created first so it sits above the annotated text.
            graph = gr.HTML()
            out = gr.HTML()
        with gr.Tab("Raw JSON"):
            raw_json = gr.JSON()

    # Changing the index refreshes each of the three views independently.
    for _handler, _target in (
        (show_annotated, out),
        (show_graph, graph),
        (get_raw_json, raw_json),
    ):
        idx.change(fn=_handler, inputs=idx, outputs=_target)

    def on_load():
        """
        On app load, pick a random row with at least one relation.

        Up to ten random rows are tried; if none of them has a relation,
        row 0 is used instead.

        Returns:
            tuple: (annotated HTML, graph HTML, raw JSON, row index)
        """
        chosen = 0
        for _ in range(10):
            candidate = random.randint(0, len(df) - 1)
            # len() rather than truthiness: the cell is presumably array-like
            # after to_pandas() -- confirm before simplifying.
            if len(df.iloc[candidate]['relations']) > 0:
                chosen = candidate
                break
        return (
            show_annotated(chosen),
            show_graph(chosen),
            get_raw_json(chosen),
            chosen,
        )

    # Populate all views (and the index field) when the page first loads.
    demo.load(
        fn=on_load,
        inputs=None,
        outputs=[out, graph, raw_json, idx],
    )
|
|
|
|
|
# Start the Gradio app assembled in the ``gr.Blocks`` context above.
demo.launch()