# DLRNA-BERTa (app.py): Gradio Space for drug-RNA interaction prediction
import gradio as gr
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
from configuration_dlmberta import InteractionModelATTNConfig
from chemberta import ChembertaTokenizer
import json
import os
from pathlib import Path
import logging
# Import visualization functions
from analysis import plot_crossattention_weights, plot_presum
from PIL import Image, ImageDraw, ImageFont
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_placeholder_image(width=600, height=400, text="No visualization available", bg_color=(0, 0, 0, 0)):
    """
    Create a transparent placeholder image with centered text.

    Args:
        width (int): Image width
        height (int): Image height
        text (str): Text to display
        bg_color (tuple): Background color (R, G, B, A) - (0, 0, 0, 0) for transparent

    Returns:
        PIL.Image: Transparent placeholder image
    """
    # Create image with transparent background
    img = Image.new('RGBA', (width, height), bg_color)
    draw = ImageDraw.Draw(img)

    # Try to load a TrueType font; fall back to PIL's built-in default font
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except OSError:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

    # Get text size and position for centering
    if font:
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
    else:
        # Rough estimation if no font is available
        text_width = len(text) * 8
        text_height = 16

    x = (width - text_width) // 2
    y = (height - text_height) // 2

    # Draw text in gray
    draw.text((x, y), text, fill=(128, 128, 128, 255), font=font)
    return img
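
# Illustrative usage of the placeholder helper (a sketch; the output path is hypothetical):
#     img = create_placeholder_image(text="Awaiting input")
#     img.save("placeholder.png")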

class DrugTargetInteractionApp:
    def __init__(self):
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        self.scaler = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_path="./"):
        """Load the pre-trained model and tokenizers"""
        try:
            # Load configuration
            config = InteractionModelATTNConfig.from_pretrained(model_path)

            # Load drug encoder (ChemBERTa)
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

            # Load target encoder
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")

            # Load scaler if it exists
            scaler_path = os.path.join(model_path, "scaler.config")
            scaler = None
            if os.path.exists(scaler_path):
                scaler = StdScaler()
                scaler.load(model_path)

            self.model = InteractionModelATTNForRegression.from_pretrained(
                model_path,
                config=config,
                target_encoder=target_encoder,
                drug_encoder=drug_encoder,
                scaler=scaler
            )
            self.model.to(self.device)
            self.model.eval()

            # Load tokenizers
            self.target_tokenizer = AutoTokenizer.from_pretrained(
                os.path.join(model_path, "target_tokenizer")
            )

            # Load drug tokenizer (ChemBERTa)
            vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json")
            self.drug_tokenizer = ChembertaTokenizer(vocab_file)

            logger.info("Model and tokenizers loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False
    def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
        """Predict drug-target interaction"""
        if self.model is None:
            return "Error: Model not loaded. Please load a model first."
        try:
            # Tokenize inputs (respect the max_length argument instead of hardcoding 512)
            target_inputs = self.target_tokenizer(
                target_sequence,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            ).to(self.device)
            drug_inputs = self.drug_tokenizer(
                drug_smiles,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            ).to(self.device)

            # Make prediction
            self.model.INTERPR_DISABLE_MODE()
            with torch.no_grad():
                prediction = self.model(target_inputs, drug_inputs)

            # Unscale if a scaler exists
            if self.model.scaler is not None:
                prediction = self.model.unscale(prediction)

            prediction_value = prediction.cpu().numpy()[0][0]
            return f"Predicted Binding Affinity: {prediction_value:.4f}"
        except Exception as e:
            logger.error(f"Prediction error: {str(e)}")
            return f"Error during prediction: {str(e)}"
    def visualize_interaction(self, target_sequence, drug_smiles):
        """
        Generate visualization images for drug-target interaction

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation

        Returns:
            tuple: (cross_attention_image, raw_contribution_image, normalized_contribution_image, status_message)
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded. Please load a model first."
        try:
            # Tokenize inputs
            target_inputs = self.target_tokenizer(
                target_sequence,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(self.device)
            drug_inputs = self.drug_tokenizer(
                drug_smiles,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(self.device)

            # Enable interpretation mode
            self.model.INTERPR_ENABLE_MODE()

            # Make prediction and extract visualization data
            with torch.no_grad():
                prediction = self.model(target_inputs, drug_inputs)

            # Unscale if a scaler exists
            if self.model.scaler is not None:
                prediction = self.model.unscale(prediction)
            prediction_value = prediction.cpu().numpy()[0][0]

            # Extract data needed for visualizations
            presum_values = self.model.model.presum_layer  # Shape: (1, seq_len)
            cross_attention_weights = self.model.model.crossattention_weights  # Shape: (batch, heads, seq_len, seq_len)

            # Get model parameters for scaling
            w = self.model.model.w.squeeze(1)
            b = self.model.model.b
            scaler = self.model.model.scaler

            logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}")
            logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}")

            # Generate visualizations
            try:
                # 1. Cross-attention heatmap
                cross_attention_img = None
                logger.info(f"Cross-attention weights type: {type(cross_attention_weights)}")
                if cross_attention_weights is not None:
                    logger.info(f"Cross-attention weights shape: {cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'}")
                    try:
                        cross_attn_matrix = cross_attention_weights[0, 0]
                        if cross_attn_matrix is not None:
                            logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}")
                            logger.info(f"Target attention mask shape: {target_inputs['attention_mask'].shape}")
                            logger.info(f"Drug attention mask shape: {drug_inputs['attention_mask'].shape}")
                            cross_attention_img = plot_crossattention_weights(
                                target_inputs["attention_mask"][0],
                                drug_inputs["attention_mask"][0],
                                target_inputs,
                                drug_inputs,
                                cross_attn_matrix,
                                self.target_tokenizer,
                                self.drug_tokenizer
                            )
                        else:
                            logger.warning("Could not extract valid cross-attention matrix")
                    except (IndexError, TypeError, AttributeError) as e:
                        logger.warning(f"Error extracting cross-attention matrix: {str(e)}")
                        cross_attn_matrix = None
                else:
                    logger.warning("Cross-attention weights are None")
            except Exception as e:
                logger.error(f"Cross-attention visualization error: {str(e)}")
                cross_attention_img = None
            try:
                # 2. Normalized contribution visualization (generated whenever presum values are available)
                normalized_img = None
                if presum_values is not None:
                    normalized_img = plot_presum(
                        target_inputs,
                        presum_values.detach(),  # Detach the tensor
                        scaler,
                        w.detach(),  # Detach the tensor
                        b.detach(),  # Detach the tensor
                        self.target_tokenizer,
                        raw_affinities=False
                    )
                else:
                    logger.warning("Cannot generate normalized visualization: presum values are None")
            except Exception as e:
                logger.error(f"Normalized contribution visualization error: {str(e)}")
                normalized_img = None

            try:
                # 3. Raw contribution visualization (only generated when pKd > 0)
                raw_img = None
                if prediction_value > 0 and presum_values is not None:
                    raw_img = plot_presum(
                        target_inputs,
                        presum_values.detach(),  # Detach the tensor
                        scaler,
                        w.detach(),  # Detach the tensor
                        b.detach(),  # Detach the tensor
                        self.target_tokenizer,
                        raw_affinities=True
                    )
                elif presum_values is None:
                    logger.warning("Cannot generate raw visualization: presum values are None")
                else:
                    logger.info("Skipping raw affinities visualization as pKd <= 0")
            except Exception as e:
                logger.error(f"Raw contribution visualization error: {str(e)}")
                raw_img = None
            # Disable interpretation mode after use
            self.model.INTERPR_DISABLE_MODE()

            # Create placeholder images if generation failed
            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(
                    text="Cross-Attention Heatmap\nFailed to generate"
                )
            if normalized_img is None:
                normalized_img = create_placeholder_image(
                    text="Normalized Contribution\nFailed to generate"
                )
            if raw_img is None and prediction_value > 0:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nFailed to generate"
                )
            elif raw_img is None:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nSkipped (pKd ≤ 0)"
                )

            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
            if prediction_value <= 0:
                status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            if cross_attention_weights is None:
                status_msg += " (Cross-attention visualization failed: weights not available)"

            return cross_attention_img, raw_img, normalized_img, status_msg
        except Exception as e:
            logger.error(f"Visualization error: {str(e)}")
            # Make sure interpretation mode is disabled even if an error occurred
            try:
                self.model.INTERPR_DISABLE_MODE()
            except Exception:
                pass
            return None, None, None, f"Error during visualization: {str(e)}"

# Initialize the app
app = DrugTargetInteractionApp()


def predict_wrapper(target_seq, drug_smiles):
    """Wrapper function for Gradio interface"""
    if not target_seq.strip() or not drug_smiles.strip():
        return "Please provide both target sequence and drug SMILES."
    return app.predict_interaction(target_seq, drug_smiles)


def visualize_wrapper(target_seq, drug_smiles):
    """Wrapper function for visualization"""
    if not target_seq.strip() or not drug_smiles.strip():
        return None, None, None, "Please provide both target sequence and drug SMILES."
    return app.visualize_interaction(target_seq, drug_smiles)


def load_model_wrapper(model_path):
    """Wrapper function to load the model"""
    if app.load_model(model_path):
        return "Model loaded successfully!"
    else:
        return "Failed to load model. Check the path and files."

# Create Gradio interface
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <div style="text-align: center; margin-bottom: 30px;">
            <h1 style="color: #2E86AB; font-size: 2.5em; margin-bottom: 10px;">
                🧬 Drug-Target Interaction Predictor
            </h1>
            <p style="font-size: 1.2em; color: #666;">
                Predict binding affinity between drugs and target RNA sequences using deep learning
            </p>
        </div>
    """)

    # State variables to share images between tabs
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()
with gr.Tab("🔮 Prediction & Analysis"):
with gr.Row():
with gr.Column(scale=1):
target_input = gr.Textbox(
label="Target RNA Sequence",
placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
lines=4,
max_lines=6
)
drug_input = gr.Textbox(
label="Drug SMILES",
placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
lines=2
)
with gr.Row():
predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")
with gr.Column(scale=1):
prediction_output = gr.Textbox(
label="Prediction Result",
interactive=False,
lines=4
)
# Example inputs
gr.HTML("<h3 style='margin-top: 20px; color: #2E86AB;'>📚 Example Inputs:</h3>")
examples = gr.Examples(
examples=[
[
"AUGCUAGCUAGUACGUAUAUCUGCACUGC",
"CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
],
[
"AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
"C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
]
],
inputs=[target_input, drug_input],
outputs=prediction_output,
fn=predict_wrapper,
cache_examples=False
)
# Button click events
predict_btn.click(
fn=predict_wrapper,
inputs=[target_input, drug_input],
outputs=prediction_output
)
def visualize_and_update(target_seq, drug_smiles):
"""Generate visualizations and update both status and state"""
img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles)
# Combine prediction result with visualization status
combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
return img1, img2, img3, combined_status
visualize_btn.click(
fn=visualize_and_update,
inputs=[target_input, drug_input],
outputs=[viz_state1, viz_state2, viz_state3, prediction_output]
)
with gr.Tab("📊 Visualizations"):
gr.HTML("""
<div style="text-align: center; margin-bottom: 20px;">
<h2 style="color: #2E86AB;">🔬 Interaction Analysis & Visualizations</h2>
<p style="font-size: 1.1em; color: #666;">
Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab
</p>
</div>
""")
# Visualization outputs - Large and vertically aligned
viz_image1 = gr.Image(
label="Cross-Attention Heatmap",
type="pil",
interactive=False,
container=True,
height=500,
value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)")
)
viz_image2 = gr.Image(
label="Raw pKd Contribution Visualization",
type="pil",
interactive=False,
container=True,
height=500,
value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)")
)
viz_image3 = gr.Image(
label="Normalized pKd Contribution Visualization",
type="pil",
interactive=False,
container=True,
height=500,
value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)")
)
# Update visualization images when state changes
viz_state1.change(
fn=lambda x: x,
inputs=viz_state1,
outputs=viz_image1
)
viz_state2.change(
fn=lambda x: x,
inputs=viz_state2,
outputs=viz_image2
)
viz_state3.change(
fn=lambda x: x,
inputs=viz_state3,
outputs=viz_image3
)
with gr.Tab("⚙️ Model Settings"):
gr.HTML("<h3 style='color: #2E86AB;'>Model Configuration</h3>")
model_path_input = gr.Textbox(
label="Model Path",
value="./",
placeholder="Path to model directory"
)
load_model_btn = gr.Button("📥 Load Model", variant="secondary")
model_status = gr.Textbox(
label="Status",
interactive=False,
value="No model loaded"
)
load_model_btn.click(
fn=load_model_wrapper,
inputs=model_path_input,
outputs=model_status
)
with gr.Tab("📊 Dataset"):
gr.Markdown("""
## Training and Test Datasets
### Fine-tuning Dataset (Training)
The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including:
- **759 unique compounds** (SMILES representations)
- **294 unique RNA sequences**
- Dissociation constants (pKd values) for binding affinity prediction
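
        (pKd is the negative base-10 logarithm of the dissociation constant, pKd = -log10(Kd), so larger values indicate tighter binding.)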

        **RNA Sequence Distribution by Type:**

        | RNA Sequence Type | Number of Interactions |
        |-------------------|------------------------|
        | Aptamers          | 520                    |
        | Ribosomal         | 295                    |
        | Viral RNAs        | 281                    |
        | miRNAs            | 146                    |
        | Riboswitches      | 100                    |
        | Repeats           | 97                     |
        | **Total**         | **1,439**              |

        ### External Evaluation Dataset (Test)
        Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**:
        - **2,991 positive interactions**
        - **2,538 negative interactions**

        **Test Dataset Composition:**
        - **1,617 aptamer pairs** (5 unique RNA sequences)
        - **1,828 viral RNA pairs** (6 unique RNA sequences)
        - **1,459 riboswitch pairs** (5 unique RNA sequences)
        - **630 miRNA pairs** (3 unique RNA sequences)

        ### Dataset Downloads
        - [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true)
        - [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true)

        ### Citation
        Original datasets published by:

        **Krishnan et al.** - Available on the RSAPred website in PDF format.

        *Reference:*
        ```bibtex
        @article{krishnan2024reliable,
            title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
            author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
            journal={Briefings in Bioinformatics},
            volume={25},
            number={2},
            pages={bbae002},
            year={2024},
            publisher={Oxford University Press}
        }
        ```
        """)
with gr.Tab("ℹ️ About"):
gr.Markdown("""
## About this application
This application implements DLRNA-BERTa, a Dual Langauge RoBERTa Transformer model for predicting drug to RNA target interactions. The model architecture includes:
- **Target encoder**: Processes RNA sequences using RNA-BERTa
- **Drug encoder**: Processes molecular SMILES notation using ChemBERTa
- **Cross-attention mechanism**: Captures interactions between drugs and targets
- **Regression head**: Predicts binding affinity scores (pKd values)
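
        A conceptual sketch of how these pieces fit together (illustrative stand-ins with hypothetical dimensions, not the trained model's actual modules):

        ```python
        import torch
        import torch.nn as nn

        hidden = 384  # hypothetical hidden size
        target_repr = torch.randn(1, 512, hidden)  # stand-in for RNA-BERTa token embeddings
        drug_repr = torch.randn(1, 512, hidden)    # stand-in for ChemBERTa token embeddings

        # Cross-attention: target tokens attend to drug tokens
        cross_attention = nn.MultiheadAttention(embed_dim=hidden, num_heads=8, batch_first=True)
        attended, attn_weights = cross_attention(query=target_repr, key=drug_repr, value=drug_repr)

        # Regression head pools over target tokens and predicts a scalar pKd
        regression_head = nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, 1))
        pkd = regression_head(attended.mean(dim=1))  # shape (1, 1)
        ```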

        ### Input requirements:
        - **Target sequence**: RNA sequence of the target (nucleotides A, U, G, C)
        - **Drug SMILES**: Simplified Molecular Input Line Entry System notation

        ### Model features:
        - Cross-attention for drug-target interaction modeling
        - Dropout for regularization
        - Layer normalization for stable training
        - Interpretability mode for contribution and attention visualization

        ### Usage tips:
        1. Load a trained model using the Model Settings tab (optional)
        2. Enter an RNA sequence and drug SMILES in the Prediction & Analysis tab
        3. Click "Predict Interaction" for binding affinity prediction only
        4. Click "Generate Visualizations" to create a detailed interaction analysis - results will appear in the Visualizations tab

        For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).
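
        The same pipeline can also be driven programmatically (a minimal sketch using this Space's `DrugTargetInteractionApp`; the model directory and inputs are placeholders):

        ```python
        from app import DrugTargetInteractionApp  # this file

        dti = DrugTargetInteractionApp()
        if dti.load_model("./"):  # directory with config.json, weights, and tokenizers
            result = dti.predict_interaction(
                "AUGCUAGCUAGUACGUAUAUCUGCACUGC",   # RNA target
                "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"   # drug SMILES (ibuprofen)
            )
            print(result)  # e.g. "Predicted Binding Affinity: 5.1234"
        ```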

        ### Visualization features:
        - **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
        - **Unnormalized pKd contribution**: Shows unnormalized signed contributions from each target token (only when pKd > 0)
        - **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token

        ### Performance metrics:
        - Trained on diverse drug-target interaction datasets
        - Evaluated using RMSE, Pearson correlation, and Concordance Index (see the sketch below)
        - Optimized for both predictive accuracy and interpretability
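
        For reference, these metrics can be computed as follows (a self-contained sketch with toy values, not the project's evaluation code):

        ```python
        import numpy as np
        from scipy import stats

        def concordance_index(y_true, y_pred):
            # Fraction of comparable pairs whose predicted ordering matches the true ordering
            total, concordant = 0, 0.0
            for i in range(len(y_true)):
                for j in range(i + 1, len(y_true)):
                    if y_true[i] != y_true[j]:
                        total += 1
                        d = (y_pred[i] - y_pred[j]) * (y_true[i] - y_true[j])
                        concordant += 1.0 if d > 0 else (0.5 if d == 0 else 0.0)
            return concordant / total

        y_true = np.array([5.2, 6.8, 4.1, 7.3])  # toy pKd values
        y_pred = np.array([5.0, 6.5, 4.8, 7.0])
        rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
        pearson_r, _ = stats.pearsonr(y_true, y_pred)
        ci = concordance_index(y_true, y_pred)
        ```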

        ### GitHub repository:
        - The full model GitHub repository can be found here: https://github.com/IlPakoZ/dlrnaberta-dti-prediction

        ### Contribution:
        - Special thanks to Umut Onur Özcan for help in developing this space :)

        ### Contact:
        - Ziaurrehman Tanoli (ziaurrehman.tanoli@helsinki.fi),
          Principal Investigator at the Institute for Molecular Medicine Finland,
          HiLIFE, University of Helsinki, Finland.
        """)

# Launch the app
if __name__ == "__main__":
    # Try to load the model on startup
    if os.path.exists("./config.json"):
        app.load_model("./")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )