# nicpopovic's picture
# Update app.py
# 3c93886 verified
from datasets import load_dataset
import gradio as gr
import pandas as pd
import re
import matplotlib.pyplot as plt
import base64
import random
from pyvis.network import Network
# --- Data Loading ---
# Load the dataset from Hugging Face and convert its 'train' split to a
# pandas DataFrame. This runs once at import time; `df` is the only table the
# display functions read, and they address rows positionally via `df.iloc`.
# Columns read elsewhere in this file: 'annotated_text', 'entity_types',
# 'entities', 'relations'.
# NOTE(review): list-valued columns presumably come back from to_pandas() as
# numpy arrays (not Python lists) — avoid truthiness tests on them; verify.
ds = load_dataset("nicpopovic/vital_articles_synthetic_information_extraction")
df = ds['train'].to_pandas()
# --- Utility Functions ---
def get_palette():
    """
    Build the pastel colors used to tint entity types.

    Returns:
        list: Hex strings of the form '#rrggbb', one per color of
        matplotlib's 'Pastel1' colormap, in colormap order. Each channel
        is scaled to 0-255 and truncated with int().
    """
    hex_colors = []
    for red, green, blue in plt.get_cmap('Pastel1').colors:
        channels = (int(red * 255), int(green * 255), int(blue * 255))
        hex_colors.append('#' + ''.join(f'{channel:02x}' for channel in channels))
    return hex_colors
def assign_type_colors(entity_types):
    """
    Assign a color to each distinct entity type.

    Types are de-duplicated and sorted before colors are handed out, so the
    result depends only on the set of types given. Duplicates no longer
    consume palette slots. Colors repeat (cycle through the palette) once
    there are more distinct types than palette entries.

    Args:
        entity_types (iterable): Entity type strings; duplicates are ignored.
    Returns:
        dict: Mapping from entity type to color hex code.
    """
    palette = get_palette()
    unique_types = sorted(set(entity_types))
    return {
        etype: palette[i % len(palette)]
        for i, etype in enumerate(unique_types)
    }
def ent_to_span(text, type_colors):
    """
    Convert <ent> tags in text to styled HTML spans for highlighting.

    Each ``<ent id="..." type="...">mention</ent>`` becomes a colored span
    holding the mention followed by a small "id: type" label. Text outside
    the tags is passed through unchanged.

    Args:
        text (str): Annotated text with <ent> tags.
        type_colors (dict): Mapping from entity type to color. Types not in
            the mapping fall back to "#cccccc88".
    Returns:
        str: HTML string with highlighted entities.
    """
    import html

    def repl(match):
        ent_id, ent_type, ent_text = match.group(1), match.group(2), match.group(3)
        # Look the color up with the raw type; escape only for output.
        color = type_colors.get(ent_type, "#cccccc88")
        # id/type land inside single-quoted attributes and the label, so a
        # stray quote or '<' in the data must not break the markup.
        safe_id = html.escape(ent_id, quote=True)
        safe_type = html.escape(ent_type, quote=True)
        return (
            f"<span class='highlight spanhighlight' "
            f"data-id='{safe_id}' data-type='{safe_type}' "
            f"style='background-color: {color}'>"
            f"{ent_text}<span class='small-text'>{safe_id}: {safe_type}</span></span>"
        )

    # Attribute values stop at the closing quote. DOTALL lets a mention that
    # contains a line break still match instead of being left as raw tags.
    pattern = r'<ent id="([^"]*)" type="([^"]*)">(.*?)</ent>'
    return re.sub(pattern, repl, text, flags=re.DOTALL)
def vis_network_html(relations, entities, type_colors):
    """
    Generate an interactive network graph as an HTML iframe.

    Only entities that appear as the subject or object of at least one
    relation become nodes. If a relation names an endpoint id that matches
    no entity, a grey placeholder node labeled with that id is added. This
    keeps pyvis from rejecting the edge and failing the whole graph.

    Args:
        relations (iterable): Relation dicts with 'subject', 'object' and
            'predicate' keys, and optionally 'description' (edge tooltip).
        entities (iterable): Entity dicts with 'id' and 'name' keys, and
            optionally 'type'.
        type_colors (dict): Mapping from entity type to color.
    Returns:
        str: HTML iframe (base64 data URI) with the network graph.
    """
    net = Network(height="350px", width="100%", directed=True, notebook=False)
    # Collect participating ids once instead of rescanning every relation
    # for every entity.
    linked_ids = set()
    for rel in relations:
        linked_ids.add(rel["subject"])
        linked_ids.add(rel["object"])
    # Add nodes for entities that participate in at least one relation
    for ent in entities:
        if ent["id"] not in linked_ids:
            continue
        color = type_colors.get(ent.get("type", ""), "#cccccc88")
        net.add_node(
            ent["id"],
            label=ent["name"],
            # No 'title', so nodes have no hover tooltip
            shape="box",
            color=color
        )
    # pyvis add_edge requires both endpoints to already be nodes; give
    # dangling endpoints a placeholder instead of crashing.
    known_ids = set(net.get_nodes())
    for rel in relations:
        for endpoint in (rel["subject"], rel["object"]):
            if endpoint not in known_ids:
                net.add_node(
                    endpoint,
                    label=str(endpoint),
                    shape="box",
                    color="#cccccc88"
                )
                known_ids.add(endpoint)
    # Add edges for relations, with tooltip as description
    for rel in relations:
        net.add_edge(
            rel["subject"],
            rel["object"],
            label=rel["predicate"],
            arrows="to",
            title=rel.get("description", "")  # Tooltip on hover
        )
    html = net.generate_html()
    html_b64 = base64.b64encode(html.encode("utf-8")).decode("utf-8")
    return f'<iframe src="data:text/html;base64,{html_b64}" width="100%" height="350px" frameborder="0"></iframe>'
# --- Gradio Display Functions ---
def _row_type_colors(row):
    """
    Build the entity-type -> color mapping shared by both views of a row.

    Colors are assigned over the union of the row's declared 'entity_types'
    and the 'type' of each of its 'entities' (a missing type counts as "").
    As a result, a given type gets the same color in the highlighted text
    and in the relation graph. Null cells are treated as empty.

    Args:
        row (pandas.Series): One row of ``df``.
    Returns:
        dict: Mapping from entity type to color hex code.
    """
    all_types = set()
    declared = row['entity_types']
    if declared is not None:
        all_types.update(declared)
    entities = row['entities']
    if entities is not None:
        all_types.update(ent.get("type", "") for ent in entities)
    return assign_type_colors(all_types)


def show_annotated(index):
    """
    Return HTML-annotated text for the given row index.

    Args:
        index (int): Row position in the DataFrame (coerced with int()).
    Returns:
        str: HTML string with highlighted entities.
    """
    row = df.iloc[int(index)]
    type_colors = _row_type_colors(row)
    html = ent_to_span(row['annotated_text'], type_colors)
    return f'<div style="font-size:1.1em;line-height:1.6">{html}</div>'


def show_graph(index):
    """
    Return an interactive graph HTML for the given row index.

    Args:
        index (int): Row position in the DataFrame (coerced with int()).
    Returns:
        str: HTML iframe with the network graph.
    """
    row = df.iloc[int(index)]
    # Treat missing list cells as empty so the graph still renders.
    entities = row['entities'] if row['entities'] is not None else []
    relations = row['relations'] if row['relations'] is not None else []
    # Same color mapping as show_annotated, so both views agree.
    type_colors = _row_type_colors(row)
    return vis_network_html(relations, entities, type_colors)
def get_raw_json(index):
    """
    Look up one dataset row and return it as a plain mapping.

    Args:
        index (int): Row position in the DataFrame; coerced with int().
    Returns:
        dict: Column name -> value for that row.
    """
    position = int(index)
    record = df.iloc[position]
    return record.to_dict()
# --- Gradio UI ---
# Custom CSS for styling the Gradio app.
# gr.Blocks(css=...) takes raw CSS. Wrapping it in <style> tags makes the
# first selector invalid, and the browser silently drops that rule.
css = """
.prose { line-height: 200%; }
.highlight { display: inline; }
.spanhighlight { padding: 2px 5px; border-radius: 5px; }
.small-text {
    padding: 2px 5px;
    background-color: white;
    border-radius: 5px;
    font-size: xx-small;
    margin-left: 0.5em;
    vertical-align: 0.2em;
    font-weight: bold;
    color: grey!important;
}
footer { display:none !important; }
.gradio-container { padding: 0!important; height:400px; }
"""
# Build the Gradio interface
with gr.Blocks(css=css, fill_height=True, fill_width=True) as demo:
    #gr.Markdown("# Data Explorer - [Dataset on Hugging Face](https://huggingface.co/datasets/nicpopovic/vital_articles_synthetic_information_extraction)")
    idx = gr.Number(
        label="Row Index",
        value=0,
        precision=0,
        minimum=0,
        maximum=len(df) - 1
    )
    with gr.Tabs():
        with gr.Tab("Annotated View"):
            graph = gr.HTML()
            out = gr.HTML()
        with gr.Tab("Raw JSON"):
            raw_json = gr.JSON()
    # Update outputs when the row index changes. Separate listeners keep
    # a failure in one view from blanking the others.
    idx.change(fn=show_annotated, inputs=idx, outputs=out)
    idx.change(fn=show_graph, inputs=idx, outputs=graph)
    idx.change(fn=get_raw_json, inputs=idx, outputs=raw_json)

    # Positions (valid for df.iloc) of rows with at least one relation,
    # computed once at startup so on_load can sample from them directly
    # instead of guessing a bounded number of times.
    rows_with_relations = [
        position
        for position, rels in enumerate(df['relations'])
        if rels is not None and len(rels) > 0
    ]

    def on_load():
        """
        On app load, pick a random row with at least one relation.

        Falls back to row 0 only when no row has any relation.
        Returns:
            tuple: (annotated HTML, graph HTML, raw JSON, row index)
        """
        if rows_with_relations:
            random_idx = random.choice(rows_with_relations)
        else:
            random_idx = 0  # fallback: no row has any relation
        return (
            show_annotated(random_idx),
            show_graph(random_idx),
            get_raw_json(random_idx),
            random_idx
        )

    # Set up the app to load a random example on startup
    demo.load(
        fn=on_load,
        inputs=None,
        outputs=[out, graph, raw_json, idx]
    )

# Launch the Gradio app
demo.launch()