import gradio as gr
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
from configuration_dlmberta import InteractionModelATTNConfig
from chemberta import ChembertaTokenizer
import json
import os
from pathlib import Path
import logging

# Import visualization functions
from analysis import plot_crossattention_weights, plot_presum
from PIL import Image, ImageDraw, ImageFont

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def create_placeholder_image(width=600, height=400, text="No visualization available", bg_color=(0, 0, 0, 0)):
    """
    Create a transparent placeholder image with text.

    Args:
        width (int): Image width
        height (int): Image height
        text (str): Text to display
        bg_color (tuple): Background color (R, G, B, A) - (0, 0, 0, 0) for transparent

    Returns:
        PIL.Image: Transparent placeholder image
    """
    # Create image with transparent background
    img = Image.new('RGBA', (width, height), bg_color)
    draw = ImageDraw.Draw(img)

    # Try to use a TrueType font, falling back to PIL's default
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except OSError:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

    # Get text size and position for centering
    if font:
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
    else:
        # Rough estimate if no font is available
        text_width = len(text) * 8
        text_height = 16

    x = (width - text_width) // 2
    y = (height - text_height) // 2

    # Draw text in gray
    draw.text((x, y), text, fill=(128, 128, 128, 255), font=font)
    return img
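
# Note: besides acting as an error fallback, create_placeholder_image is used below
# as the default value of the gr.Image components, e.g.
#   create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)")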
class DrugTargetInteractionApp:
    def __init__(self):
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        self.scaler = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_path="./"):
        """Load the pre-trained model and tokenizers"""
        try:
            # Load configuration
            config = InteractionModelATTNConfig.from_pretrained(model_path)

            # Load drug encoder (ChemBERTa)
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

            # Load target encoder
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")

            # Load scaler if it exists
            scaler_path = os.path.join(model_path, "scaler.config")
            scaler = None
            if os.path.exists(scaler_path):
                scaler = StdScaler()
                scaler.load(model_path)

            self.model = InteractionModelATTNForRegression.from_pretrained(
                model_path,
                config=config,
                target_encoder=target_encoder,
                drug_encoder=drug_encoder,
                scaler=scaler
            )
            self.model.to(self.device)
            self.model.eval()

            # Load tokenizers
            self.target_tokenizer = AutoTokenizer.from_pretrained(
                os.path.join(model_path, "target_tokenizer")
            )

            # Load drug tokenizer (ChemBERTa)
            vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json")
            self.drug_tokenizer = ChembertaTokenizer(vocab_file)

            logger.info("Model and tokenizers loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False
    def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
        """Predict drug-target interaction"""
        if self.model is None:
            return "Error: Model not loaded. Please load a model first."
        try:
            # Tokenize inputs
            target_inputs = self.target_tokenizer(
                target_sequence,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            ).to(self.device)

            drug_inputs = self.drug_tokenizer(
                drug_smiles,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            ).to(self.device)

            # Make prediction
            self.model.INTERPR_DISABLE_MODE()
            with torch.no_grad():
                prediction = self.model(target_inputs, drug_inputs)

                # Unscale if scaler exists
                if self.model.scaler is not None:
                    prediction = self.model.unscale(prediction)

            prediction_value = prediction.cpu().numpy()[0][0]
            return f"Predicted Binding Affinity: {prediction_value:.4f}"
        except Exception as e:
            logger.error(f"Prediction error: {str(e)}")
            return f"Error during prediction: {str(e)}"
    def visualize_interaction(self, target_sequence, drug_smiles):
        """
        Generate visualization images for drug-target interaction.

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation

        Returns:
            tuple: (cross_attention_image, raw_contribution_image, normalized_contribution_image, status_message)
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded. Please load a model first."
        try:
            # Tokenize inputs
            target_inputs = self.target_tokenizer(
                target_sequence,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(self.device)

            drug_inputs = self.drug_tokenizer(
                drug_smiles,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(self.device)

            # Enable interpretation mode
            self.model.INTERPR_ENABLE_MODE()

            # Make prediction and extract visualization data
            with torch.no_grad():
                prediction = self.model(target_inputs, drug_inputs)

                # Unscale if scaler exists
                if self.model.scaler is not None:
                    prediction = self.model.unscale(prediction)

            prediction_value = prediction.cpu().numpy()[0][0]

            # Extract data needed for visualizations
            presum_values = self.model.model.presum_layer  # Shape: (1, seq_len)
            cross_attention_weights = self.model.model.crossattention_weights  # Shape: (batch, heads, seq_len, seq_len)

            # Get model parameters for scaling
            w = self.model.model.w.squeeze(1)
            b = self.model.model.b
            scaler = self.model.model.scaler

            logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}")
            logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}")
            # Generate visualizations
            try:
                # 1. Cross-attention heatmap
                cross_attention_img = None
                logger.info(f"Cross-attention weights type: {type(cross_attention_weights)}")
                if cross_attention_weights is not None:
                    logger.info(f"Cross-attention weights shape: {cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'}")
                    try:
                        cross_attn_matrix = cross_attention_weights[0, 0]
                        if cross_attn_matrix is not None:
                            logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}")
                            logger.info(f"Target attention mask shape: {target_inputs['attention_mask'].shape}")
                            logger.info(f"Drug attention mask shape: {drug_inputs['attention_mask'].shape}")
                            cross_attention_img = plot_crossattention_weights(
                                target_inputs["attention_mask"][0],
                                drug_inputs["attention_mask"][0],
                                target_inputs,
                                drug_inputs,
                                cross_attn_matrix,
                                self.target_tokenizer,
                                self.drug_tokenizer
                            )
                        else:
                            logger.warning("Could not extract valid cross-attention matrix")
                    except (IndexError, TypeError, AttributeError) as e:
                        logger.warning(f"Error extracting cross-attention matrix: {str(e)}")
                        cross_attn_matrix = None
                else:
                    logger.warning("Cross-attention weights are None")
            except Exception as e:
                logger.error(f"Cross-attention visualization error: {str(e)}")
                cross_attention_img = None
            try:
                # 2. Normalized contribution visualization (generated whenever presum values exist)
                normalized_img = None
                if presum_values is not None:
                    normalized_img = plot_presum(
                        target_inputs,
                        presum_values.detach(),  # Detach the tensor
                        scaler,
                        w.detach(),  # Detach the tensor
                        b.detach(),  # Detach the tensor
                        self.target_tokenizer,
                        raw_affinities=False
                    )
                else:
                    logger.warning("Cannot generate normalized visualization: presum values are None")
            except Exception as e:
                logger.error(f"Normalized contribution visualization error: {str(e)}")
                normalized_img = None
            try:
                # 3. Raw contribution visualization (only when pKd > 0)
                raw_img = None
                if prediction_value > 0 and presum_values is not None:
                    raw_img = plot_presum(
                        target_inputs,
                        presum_values.detach(),  # Detach the tensor
                        scaler,
                        w.detach(),  # Detach the tensor
                        b.detach(),  # Detach the tensor
                        self.target_tokenizer,
                        raw_affinities=True
                    )
                else:
                    if prediction_value <= 0:
                        logger.info("Skipping raw affinities visualization because pKd <= 0")
                    if presum_values is None:
                        logger.warning("Cannot generate raw visualization: presum values are None")
            except Exception as e:
                logger.error(f"Raw contribution visualization error: {str(e)}")
                raw_img = None
            # Disable interpretation mode after use
            self.model.INTERPR_DISABLE_MODE()

            # Create placeholder images if generation failed
            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(
                    text="Cross-Attention Heatmap\nFailed to generate"
                )
            if normalized_img is None:
                normalized_img = create_placeholder_image(
                    text="Normalized Contribution\nFailed to generate"
                )
            if raw_img is None and prediction_value > 0:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nFailed to generate"
                )
            elif raw_img is None:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nSkipped (pKd ≤ 0)"
                )

            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
            if prediction_value <= 0:
                status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            if cross_attention_weights is None:
                status_msg += " (Cross-attention visualization failed: weights not available)"

            return cross_attention_img, raw_img, normalized_img, status_msg
        except Exception as e:
            logger.error(f"Visualization error: {str(e)}")
            # Make sure interpretation mode is disabled even if an error occurred
            try:
                self.model.INTERPR_DISABLE_MODE()
            except Exception:
                pass
            return None, None, None, f"Error during visualization: {str(e)}"
# Initialize the app
app = DrugTargetInteractionApp()
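
# Illustrative programmatic usage (a sketch; assumes a trained checkpoint with a
# config.json, the tokenizer folders, and an optional scaler.config in the working directory):
#   app.load_model("./")
#   print(app.predict_interaction("AUGCUAGCUAGUACGUA", "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"))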
def predict_wrapper(target_seq, drug_smiles):
    """Wrapper function for Gradio interface"""
    if not target_seq.strip() or not drug_smiles.strip():
        return "Please provide both target sequence and drug SMILES."
    return app.predict_interaction(target_seq, drug_smiles)


def visualize_wrapper(target_seq, drug_smiles):
    """Wrapper function for visualization"""
    if not target_seq.strip() or not drug_smiles.strip():
        return None, None, None, "Please provide both target sequence and drug SMILES."
    return app.visualize_interaction(target_seq, drug_smiles)


def load_model_wrapper(model_path):
    """Wrapper function to load model"""
    if app.load_model(model_path):
        return "Model loaded successfully!"
    else:
        return "Failed to load model. Check the path and files."
# Create Gradio interface
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <div style="text-align: center; margin-bottom: 30px;">
            <h1 style="color: #2E86AB; font-size: 2.5em; margin-bottom: 10px;">
                🧬 Drug-Target Interaction Predictor
            </h1>
            <p style="font-size: 1.2em; color: #666;">
                Predict binding affinity between drugs and target RNA sequences using deep learning
            </p>
        </div>
    """)

    # Create state variables to share images between tabs
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()

    with gr.Tab("🔮 Prediction & Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                target_input = gr.Textbox(
                    label="Target RNA Sequence",
                    placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
                    lines=4,
                    max_lines=6
                )
                drug_input = gr.Textbox(
                    label="Drug SMILES",
                    placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
                    lines=2
                )
                with gr.Row():
                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
                    visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")
            with gr.Column(scale=1):
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    interactive=False,
                    lines=4
                )

        # Example inputs
        gr.HTML("<h3 style='margin-top: 20px; color: #2E86AB;'>📚 Example Inputs:</h3>")
        examples = gr.Examples(
            examples=[
                [
                    "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
                ],
                [
                    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
                    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
                ]
            ],
            inputs=[target_input, drug_input],
            outputs=prediction_output,
            fn=predict_wrapper,
            cache_examples=False
        )

        # Button click events
        predict_btn.click(
            fn=predict_wrapper,
            inputs=[target_input, drug_input],
            outputs=prediction_output
        )

        def visualize_and_update(target_seq, drug_smiles):
            """Generate visualizations and update both status and state"""
            img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles)
            # Combine the prediction result with the visualization status
            combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
            return img1, img2, img3, combined_status

        visualize_btn.click(
            fn=visualize_and_update,
            inputs=[target_input, drug_input],
            outputs=[viz_state1, viz_state2, viz_state3, prediction_output]
        )
with gr.Tab("📊 Visualizations"): | |
gr.HTML(""" | |
<div style="text-align: center; margin-bottom: 20px;"> | |
<h2 style="color: #2E86AB;">🔬 Interaction Analysis & Visualizations</h2> | |
<p style="font-size: 1.1em; color: #666;"> | |
Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab | |
</p> | |
</div> | |
""") | |
# Visualization outputs - Large and vertically aligned | |
viz_image1 = gr.Image( | |
label="Cross-Attention Heatmap", | |
type="pil", | |
interactive=False, | |
container=True, | |
height=500, | |
value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)") | |
) | |
viz_image2 = gr.Image( | |
label="Raw pKd Contribution Visualization", | |
type="pil", | |
interactive=False, | |
container=True, | |
height=500, | |
value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)") | |
) | |
viz_image3 = gr.Image( | |
label="Normalized pKd Contribution Visualization", | |
type="pil", | |
interactive=False, | |
container=True, | |
height=500, | |
value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)") | |
) | |
# Update visualization images when state changes | |
viz_state1.change( | |
fn=lambda x: x, | |
inputs=viz_state1, | |
outputs=viz_image1 | |
) | |
viz_state2.change( | |
fn=lambda x: x, | |
inputs=viz_state2, | |
outputs=viz_image2 | |
) | |
viz_state3.change( | |
fn=lambda x: x, | |
inputs=viz_state3, | |
outputs=viz_image3 | |
) | |
with gr.Tab("⚙️ Model Settings"): | |
gr.HTML("<h3 style='color: #2E86AB;'>Model Configuration</h3>") | |
model_path_input = gr.Textbox( | |
label="Model Path", | |
value="./", | |
placeholder="Path to model directory" | |
) | |
load_model_btn = gr.Button("📥 Load Model", variant="secondary") | |
model_status = gr.Textbox( | |
label="Status", | |
interactive=False, | |
value="No model loaded" | |
) | |
load_model_btn.click( | |
fn=load_model_wrapper, | |
inputs=model_path_input, | |
outputs=model_status | |
) | |
with gr.Tab("📊 Dataset"): | |
gr.Markdown(""" | |
## Training and Test Datasets | |
### Fine-tuning Dataset (Training) | |
The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including: | |
- **759 unique compounds** (SMILES representations) | |
- **294 unique RNA sequences** | |
- Dissociation constants (pKd values) for binding affinity prediction | |
**RNA Sequence Distribution by Type:** | |
| RNA Sequence Type | Number of Interactions | | |
|-------------------|------------------------| | |
| Aptamers | 520 | | |
| Ribosomal | 295 | | |
| Viral RNAs | 281 | | |
| miRNAs | 146 | | |
| Riboswitches | 100 | | |
| Repeats | 97 | | |
| **Total** | **1,439** | | |
### External Evaluation Dataset (Test) | |
Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**: | |
- **2,991 positive interactions** | |
- **2,538 negative interactions** | |
**Test Dataset Composition:** | |
- **1,617 aptamer pairs** (5 unique RNA sequences) | |
- **1,828 viral RNA pairs** (6 unique RNA sequences) | |
- **1,459 riboswitch pairs** (5 unique RNA sequences) | |
- **630 miRNA pairs** (3 unique RNA sequences) | |
### Dataset Downloads | |
- [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true) | |
- [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true) | |
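
        Either file can also be loaded directly from the Hub, for example with pandas (a minimal sketch; the column layout is not documented here, so inspect it after loading):

        ```python
        import pandas as pd

        URL = "https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv"
        df = pd.read_csv(URL)               # fetch and parse the training CSV
        print(df.shape, list(df.columns))   # inspect size and column names
        ```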
        ### Citation
        Original datasets published by:

        **Krishnan et al.** - Available on the RSAPred website in PDF format.

        *Reference:*
        ```bibtex
        @article{krishnan2024reliable,
          title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
          author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
          journal={Briefings in Bioinformatics},
          volume={25},
          number={2},
          pages={bbae002},
          year={2024},
          publisher={Oxford University Press}
        }
        ```
        """)
with gr.Tab("ℹ️ About"): | |
gr.Markdown(""" | |
## About this application | |
This application implements DLRNA-BERTa, a Dual Langauge RoBERTa Transformer model for predicting drug to RNA target interactions. The model architecture includes: | |
- **Target encoder**: Processes RNA sequences using RNA-BERTa | |
- **Drug encoder**: Processes molecular SMILES notation using ChemBERTa | |
- **Cross-attention mechanism**: Captures interactions between drugs and targets | |
- **Regression head**: Predicts binding affinity scores (pKd values) | |
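
        Conceptually, the cross-attention fusion and regression head behave like the following PyTorch sketch (illustrative only; the dimensions, pooling, and module names are assumptions, not the repository's actual implementation):

        ```python
        import torch
        import torch.nn as nn

        # Toy dimensions, chosen only for illustration
        d_model, n_heads, t_len, d_len = 64, 4, 16, 12

        target_states = torch.randn(1, t_len, d_model)  # stand-in for RNA-BERTa output
        drug_states = torch.randn(1, d_len, d_model)    # stand-in for ChemBERTa output

        # Target tokens attend over drug tokens
        cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        fused, weights = cross_attn(query=target_states, key=drug_states, value=drug_states)

        # Pool the fused target states and regress a scalar pKd
        regression_head = nn.Linear(d_model, 1)
        pkd = regression_head(fused.mean(dim=1))
        print(pkd.shape)  # torch.Size([1, 1])
        ```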
        ### Input requirements:
        - **Target sequence**: RNA sequence of the target (nucleotide sequences: A, U, G, C)
        - **Drug SMILES**: Simplified Molecular Input Line Entry System notation

        ### Model features:
        - Cross-attention for drug-target interaction modeling
        - Dropout for regularization
        - Layer normalization for stable training
        - Interpretability mode for contribution and attention visualization

        ### Usage tips:
        1. Load a trained model using the Model Settings tab (optional)
        2. Enter an RNA sequence and drug SMILES in the Prediction & Analysis tab
        3. Click "Predict Interaction" for binding affinity prediction only
        4. Click "Generate Visualizations" to create a detailed interaction analysis - results will appear in the Visualizations tab

        For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens); a quick length check is sketched below.
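
        ```python
        # A minimal token-length check (a sketch; it assumes the matching tokenizer
        # is published on the Hub under the same ID as the target encoder)
        from transformers import AutoTokenizer

        tok = AutoTokenizer.from_pretrained("IlPakoZ/RNA-BERTa9700")
        n_tokens = len(tok("AUGCUAGCUAGUACGUA")["input_ids"])
        print(n_tokens, n_tokens <= 512)
        ```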
        ### Visualization features:
        - **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
        - **Raw pKd contribution**: Shows unnormalized signed contributions from each target token (only when pKd > 0)
        - **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token

        ### Performance metrics:
        - Trained on diverse drug-target interaction datasets
        - Evaluated using RMSE, Pearson correlation, and Concordance Index
        - Optimized for both predictive accuracy and interpretability

        ### GitHub repository:
        - The full model repository is available at https://github.com/IlPakoZ/dlrnaberta-dti-prediction

        ### Contribution:
        - Special thanks to Umut Onur Özcan for help in developing this Space :)

        ### Contact:
        - Ziaurrehman Tanoli (ziaurrehman.tanoli@helsinki.fi),
          Principal Investigator at the Institute for Molecular Medicine Finland,
          HiLIFE, University of Helsinki, Finland.
        """)
# Launch the app
if __name__ == "__main__":
    # Try to load the model on startup
    if os.path.exists("./config.json"):
        app.load_model("./")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )