# emotional_ad_predictions / emotion_analysis_framework.py
# Uploaded by sivdma ("Upload 29 files", commit 61c258f, verified).
"""
Emotion Analysis Framework
A framework for analyzing emotions from patient audio recordings using wav2vec2 models.
Author: Marek Sviderski
This framework supports three main tasks:
- Binary classification: Distinguishing between Alzheimer's Disease (AD) and Healthy Control (HC)
- Multiclass classification: Classifying into HC, Mild Cognitive Impairment (MCI), and AD
- Regression: Predicting the Mini-Mental State Examination (MMSE) score
This code is designed to be modular and extensible, allowing for easy integration of new models and strategies.
It uses dataclasses for structured data representation and provides methods for feature extraction, model loading, and predictions.
"""
import os
from pathlib import Path
from typing import Dict, List, Optional, Union, Any
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
import torch
import joblib
import warnings
from utils.config import ProjectConfig
from utils.config_types import ChunkLength
from models.inference_wav2vec import Wav2VecInference
from preprocessing.flattening_statistical import StatisticalFlattening
from preprocessing.flattening_categorical import CategoricalFlattening
from preprocessing.flattening_minirocket import MiniRocketFlattening
# Silence noisy third-party warnings at import time so they do not clutter
# model loading / inference output. These filters are process-wide.
# - xgboost UserWarnings (presumably raised when unpickling the fold models —
#   TODO confirm which models are xgboost-based).
# - "Some weights of ..." notices; the `message` argument is a regex matched
#   against the start of the warning text. Presumably emitted when the
#   wav2vec2 checkpoint is loaded with a freshly initialised head — confirm.
warnings.filterwarnings('ignore', category=UserWarning, module='xgboost')
warnings.filterwarnings('ignore', message='Some weights of the model checkpoint')
warnings.filterwarnings('ignore', message='Some weights of Wav2Vec2ForSequenceClassification')
@dataclass
class PatientData:
    """Input record describing one patient recording to be analysed.

    Construct positionally as ``PatientData(patient_id, audio_path)`` or with
    an explicit ``demographics`` mapping.
    """

    # Identifier stamped onto every PredictionResult produced for this patient.
    patient_id: str
    # Path of the audio file that wav2vec features are extracted from.
    audio_path: str
    # Extra per-patient values appended as columns to the flattened feature
    # tables; every instance gets its own fresh dict (no shared default).
    demographics: Dict[str, Any] = field(default_factory=dict)
@dataclass
class PredictionResult:
    """Data class for prediction results.

    ``predictions`` holds task-specific keys: ``'class'``/``'label'`` for the
    ``'binary'`` and ``'multiclass'`` tasks, ``'mmse_score'``/``'std'`` for
    ``'regression'``, or ``'error'`` when the task failed.
    """

    patient_id: str
    task: str
    predictions: Dict[str, Any]
    # Per-class probabilities (classification tasks only).
    probabilities: Optional[Dict[str, Any]] = None
    # May be None, e.g. for results describing a failed task.
    confidence: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def summary(self) -> str:
        """Return a one-line, human-readable summary of the prediction.

        Normal results render as before, e.g. ``"Binary: AD (confidence: 0.88)"``
        or ``"MMSE Score: 24.4 ± 1.0"``. Failed results (``predictions``
        containing ``'error'``) and results without a confidence or std value
        are summarised without raising. Unknown tasks fall back to
        ``str(self.predictions)``.
        """
        if 'error' in self.predictions:
            return f"{self.task}: error - {self.predictions['error']}"

        titles = {'binary': 'Binary', 'multiclass': 'Multiclass'}
        if self.task in titles and 'label' in self.predictions:
            text = f"{titles[self.task]}: {self.predictions['label']}"
            # confidence is Optional; formatting None with :.2f would raise.
            if self.confidence is not None:
                text += f" (confidence: {self.confidence:.2f})"
            return text

        if self.task == 'regression' and 'mmse_score' in self.predictions:
            text = f"MMSE Score: {self.predictions['mmse_score']:.1f}"
            std = self.predictions.get('std')
            if std is not None:
                text += f" ± {std:.1f}"
            return text

        return str(self.predictions)
class EmotionAnalysisFramework:
    """
    End-to-end framework for emotion analysis from patient recordings.

    This framework provides three prediction tasks:
    - Binary classification: AD vs Healthy Control (HC)
    - Multiclass classification: HC vs MCI vs AD
    - Regression: MMSE score prediction

    Each task combines several per-fold models loaded from
    ``<model_dir>/<task>/model_fold_<n>.joblib``.

    Args:
        config_path: Optional path to custom configuration
        model_dir: Path to directory containing model weights; defaults to the
            ``model_weights`` directory next to this module
        verbose: Whether to print detailed progress information

    Raises:
        ValueError: If the model directory does not exist, or if any task ends
            up with no successfully loaded fold model.
    """

    # Tasks in the order they are loaded, and run when no task is requested.
    _TASKS = ('binary', 'multiclass', 'regression')
    # Multiclass labels indexed by the fused-probability argmax
    # (assumes the fold models' classes are ordered HC, MCI, AD — TODO confirm).
    _MULTICLASS_LABELS = ('HC', 'MCI', 'AD')
    # Regression fusion components, in weight order:
    # (key in model_pack['models'], key in the prepared feature dict).
    _REGRESSION_COMPONENTS = (
        ('ridge_1_5_minirocket', 'chunk_1_5'),
        ('ridge_3_5_minirocket', 'chunk_3_5'),
        ('ridge_4_5_categorical', 'chunk_4_5_demo'),
    )

    def __init__(self, config_path: Optional[str] = None,
                 model_dir: Optional[str] = None,
                 verbose: bool = False):
        self.config = ProjectConfig(config_path) if config_path else ProjectConfig()
        self.model_dir = model_dir
        self.verbose = verbose
        # Loaded lazily on the first feature extraction.
        self.wav2vec_model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # task -> model_type ('simple' / 'fusion') -> fold id -> loaded object
        self.models = {task: {} for task in self._TASKS}
        self.strategies = {}
        self._initialize_strategies()
        self._load_models()

    def _log(self, message: str):
        """Print message only if verbose mode is enabled"""
        if self.verbose:
            print(message)

    def _initialize_strategies(self):
        """Initialize all flattening strategies, keyed by the name used in _prepare_features."""
        self.strategies = {
            'statistical': StatisticalFlattening(),
            'categorical': CategoricalFlattening(),
            'minirocket': MiniRocketFlattening()
        }

    def _load_models(self):
        """Load the per-fold models of every task from ``self.model_dir``.

        Raises:
            ValueError: If the directory does not exist, or if a task has no
                successfully loaded fold model.
        """
        if not self.model_dir:
            # Fall back to the ``model_weights`` directory next to this module.
            package_dir = os.path.dirname(os.path.abspath(__file__))
            self.model_dir = os.path.join(package_dir, "model_weights")
        if not os.path.exists(self.model_dir):
            raise ValueError(f"Model directory not found: {self.model_dir}")

        for task in self._TASKS:
            self._load_task_models(task, os.path.join(self.model_dir, task))

        # _load_task_models always creates the model_type entry (even when no
        # fold loads), so check that some fold dict is actually non-empty.
        for task in self._TASKS:
            if not any(self.models[task].values()):
                raise ValueError(f"No {task} models found in {self.model_dir}")

    def _load_task_models(self, task: str, path: str):
        """Load every ``model_fold_*.joblib`` in *path* into ``self.models[task]``.

        Binary and regression folds are stored under ``'simple'``, multiclass
        folds under ``'fusion'``. A missing directory or an unloadable file is
        logged (verbose mode) and skipped rather than raised.
        """
        if not os.path.exists(path):
            self._log(f"Warning: {task} model path not found: {path}")
            return
        model_type = 'simple' if task in ('binary', 'regression') else 'fusion'
        fold_models = {}
        self.models[task][model_type] = fold_models

        # os.listdir order is arbitrary; sort so loading (and logs) are deterministic.
        model_files = sorted(f for f in os.listdir(path)
                             if f.startswith('model_fold_') and f.endswith('.joblib'))
        for file_name in model_files:
            # "model_fold_3.joblib" -> "3"
            fold_num = file_name.split('_')[-1].replace('.joblib', '')
            model_path = os.path.join(path, file_name)
            try:
                # NOTE: joblib.load unpickles arbitrary objects; only point
                # model_dir at trusted model files.
                fold_models[fold_num] = joblib.load(model_path)
                self._log(f"Loaded {task} {model_type} model fold {fold_num}")
            except Exception as e:
                self._log(f"Error loading model {model_path}: {e}")

    def _extract_wav2vec_features(self, audio_path: str, chunk_length: ChunkLength) -> pd.DataFrame:
        """Extract wav2vec emotion features from an audio file.

        Returns one row per analysed segment with columns ``filename`` (file
        stem), ``start``, ``end`` and one column per emotion key returned by
        ``analyze_emotions_over_time``.
        """
        if self.wav2vec_model is None:
            model = Wav2VecInference(self.config, verbose=self.verbose)
            model.load_model()
            # Cache only after a successful load so a failed load is retried
            # on the next call instead of leaving a half-initialised model.
            self.wav2vec_model = model

        chunk_config = self.config.get_chunk_params(chunk_length)
        emotions_over_time = self.wav2vec_model.analyze_emotions_over_time(
            audio_path,
            segment_duration=chunk_config.segment_duration,
            overlap_duration=chunk_config.overlap_duration
        )
        filename = str(Path(audio_path).stem)
        rows = [{'filename': filename, 'start': start, 'end': end, **emotions}
                for start, end, emotions in emotions_over_time]
        return pd.DataFrame(rows)

    @staticmethod
    def _attach_demographics(flattened: pd.DataFrame,
                             demographics: Dict[str, Any]) -> pd.DataFrame:
        """Append the demographic values as columns on every row of *flattened*.

        Returns *flattened* unchanged when *demographics* is empty.
        """
        if not demographics:
            return flattened
        # concat(axis=1) aligns on the index, so build the demographics frame on
        # flattened's own index. Otherwise a non-0..n-1 index (or more than one
        # row) produces misaligned rows padded with NaN.
        demo_df = pd.DataFrame([demographics] * len(flattened), index=flattened.index)
        return pd.concat([flattened, demo_df], axis=1)

    def _prepare_features(self, patient_data: PatientData, task: str) -> Dict[str, pd.DataFrame]:
        """Build the flattened feature tables required by *task*.

        Returns a dict keyed by the names the matching ``_predict_*`` method
        reads: ``'simple'`` (binary), ``'chunk1'``/``'chunk2'`` (multiclass), or
        ``'chunk_1_5'``/``'chunk_3_5'``/``'chunk_4_5_demo'`` (regression).
        An unknown task yields an empty dict.
        """
        audio_path = patient_data.audio_path
        demographics = patient_data.demographics
        prepared_data = {}

        if task == 'binary':
            # Statistical flattening of 3.5 s chunks, plus demographics.
            df_3_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_3_5)
            flattened = self.strategies['statistical'].flatten_dataframe(df_3_5)
            prepared_data['simple'] = self._attach_demographics(flattened, demographics)

        elif task == 'multiclass':
            # Categorical flattening of 1.5 s and 4.5 s chunks, both with demographics.
            df_1_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_1_5)
            df_4_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_4_5)
            flattened_1_5 = self.strategies['categorical'].flatten_dataframe(df_1_5)
            flattened_4_5 = self.strategies['categorical'].flatten_dataframe(df_4_5)
            prepared_data['chunk1'] = self._attach_demographics(flattened_1_5, demographics)
            prepared_data['chunk2'] = self._attach_demographics(flattened_4_5, demographics)

        elif task == 'regression':
            # MiniRocket for 1.5 s and 3.5 s chunks, categorical for 4.5 s;
            # only the 4.5 s table receives demographics.
            df_1_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_1_5)
            df_3_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_3_5)
            df_4_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_4_5)
            flattened_1_5 = self.strategies['minirocket'].flatten_dataframe(df_1_5)
            flattened_3_5 = self.strategies['minirocket'].flatten_dataframe(df_3_5)
            flattened_4_5 = self.strategies['categorical'].flatten_dataframe(df_4_5)
            prepared_data['chunk_1_5'] = flattened_1_5
            prepared_data['chunk_3_5'] = flattened_3_5
            prepared_data['chunk_4_5_demo'] = self._attach_demographics(flattened_4_5, demographics)

        return prepared_data

    def _select_model_features(self, model: Any, X: pd.DataFrame) -> pd.DataFrame:
        """Subset and reorder *X* to the columns *model* was fitted on.

        This applies only when the model exposes ``feature_names_in_``;
        otherwise *X* is returned as-is. Expected columns missing from *X* are
        left out of the selection and reported in verbose mode, so the model's
        own input validation decides whether that is fatal.
        """
        if not hasattr(model, 'feature_names_in_'):
            return X
        expected = model.feature_names_in_
        present = [f for f in expected if f in X.columns]
        if len(present) < len(expected):
            self._log(f"Warning: {len(expected) - len(present)} expected feature(s) missing from input")
        return X[present]

    @staticmethod
    def _first_filename(X: pd.DataFrame) -> str:
        """Return the first ``filename`` value of *X*, or ``'unknown'`` if unavailable."""
        if 'filename' in X.columns and not X.empty:
            return X['filename'].iloc[0]
        return 'unknown'

    @staticmethod
    def _majority_vote(fold_predictions: List[Any], mean_probabilities: np.ndarray) -> int:
        """Return the most frequently predicted class index across folds.

        Ties are broken in favour of the tied class with the highest mean
        probability (then the lowest index).
        """
        votes = np.bincount(np.asarray(fold_predictions, dtype=int),
                            minlength=len(mean_probabilities))
        tied = np.flatnonzero(votes == votes.max())
        return int(tied[np.argmax(np.asarray(mean_probabilities)[tied])])

    def _predict_binary(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
        """Predict AD vs HC by combining all binary fold models.

        The class is the majority vote of the fold predictions: the rounded
        mean, where a 50/50 tie rounds to 0 (HC) because ``np.round`` rounds
        half to even. Probabilities are the mean of the per-fold
        ``predict_proba`` rows. Column 0 is reported as HC and column 1 as AD
        (assumes fold models use classes [0, 1] in that order — TODO confirm).
        Folds that raise are logged and skipped.

        Raises:
            ValueError: If no binary models are loaded or every fold failed.
        """
        model_type = 'simple'
        fold_models = self.models['binary'].get(model_type)
        if not fold_models:
            raise ValueError("No binary models loaded")

        X_all = features[model_type]
        patient_id = self._first_filename(X_all)

        fold_predictions = []
        fold_probabilities = []
        for fold_num, model in fold_models.items():
            try:
                X = self._select_model_features(model, X_all)
                pred = model.predict(X)
                pred_proba = model.predict_proba(X)
                # Only the first row is used: one recording per prediction.
                fold_predictions.append(pred[0])
                fold_probabilities.append(pred_proba[0])
            except Exception as e:
                self._log(f"Error in fold {fold_num}: {e}")

        if not fold_predictions:
            raise ValueError("No successful predictions from any fold")

        final_prediction = int(np.round(np.mean(fold_predictions)))
        mean_probabilities = np.mean(fold_probabilities, axis=0)
        return PredictionResult(
            patient_id=patient_id,
            task='binary',
            predictions={'class': final_prediction, 'label': 'AD' if final_prediction else 'HC'},
            probabilities={'HC': float(mean_probabilities[0]), 'AD': float(mean_probabilities[1])},
            confidence=float(np.max(mean_probabilities)),
            metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
        )

    def _predict_multiclass(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
        """Predict HC / MCI / AD by combining all multiclass fusion folds.

        Each fold pack holds a ``'chunk1'`` and a ``'chunk2'`` model. Their
        ``predict_proba`` outputs are blended with the pack's ``'weights'``
        (default ``[0.5, 0.5]``), and the fold's class is the argmax. The final
        class is the majority vote across folds (ties broken by mean
        probability). Averaging class indices is deliberately avoided: it would
        turn an HC/AD split into MCI. Folds that raise are logged and skipped.

        Raises:
            ValueError: If no fusion models are loaded or every fold failed.
        """
        fold_packs = self.models['multiclass'].get('fusion')
        if not fold_packs:
            raise ValueError("No multiclass fusion model loaded")

        class_labels = list(self._MULTICLASS_LABELS)
        patient_id = self._first_filename(features['chunk1'])

        fold_predictions = []
        fold_probabilities = []
        for fold_num, model_pack in fold_packs.items():
            try:
                model1 = model_pack['chunk1']
                model2 = model_pack['chunk2']
                X_chunk1 = self._select_model_features(model1, features['chunk1'])
                X_chunk2 = self._select_model_features(model2, features['chunk2'])
                pred_proba_1 = model1.predict_proba(X_chunk1)
                pred_proba_2 = model2.predict_proba(X_chunk2)

                weights = model_pack.get('weights', [0.5, 0.5])
                fusion_proba = weights[0] * pred_proba_1 + weights[1] * pred_proba_2
                pred = np.argmax(fusion_proba, axis=1)
                fold_predictions.append(pred[0])
                fold_probabilities.append(fusion_proba[0])
            except Exception as e:
                self._log(f"Error in multiclass fold {fold_num}: {e}")

        if not fold_predictions:
            raise ValueError("No successful multiclass predictions from any fold")

        mean_probabilities = np.mean(fold_probabilities, axis=0)
        final_prediction = self._majority_vote(fold_predictions, mean_probabilities)
        prob_dict = {label: float(prob) for label, prob in zip(class_labels, mean_probabilities)}
        return PredictionResult(
            patient_id=patient_id,
            task='multiclass',
            predictions={'class': final_prediction, 'label': class_labels[final_prediction]},
            probabilities=prob_dict,
            confidence=float(np.max(mean_probabilities)),
            metadata={'num_folds': len(fold_predictions)}
        )

    def _predict_regression(self, features: Dict[str, pd.DataFrame]) -> PredictionResult:
        """Predict the MMSE score by combining all regression fusion folds.

        Each fold pack holds ``'models'`` (the three components in
        ``_REGRESSION_COMPONENTS``) and ``'weights'``. The fold prediction is
        the weighted sum of the component predictions. The result is the mean
        over folds, with the fold standard deviation reported as ``'std'``.
        Confidence is ``1 / (1 + std)``, which equals 1.0 when only one fold
        succeeds. Folds that raise are logged and skipped.

        Raises:
            ValueError: If no regression models are loaded or every fold failed.
        """
        model_type = 'simple'
        fold_packs = self.models['regression'].get(model_type)
        if not fold_packs:
            raise ValueError("No regression models loaded")

        patient_id = self._first_filename(features['chunk_1_5'])

        fold_predictions = []
        for fold_num, model_pack in fold_packs.items():
            try:
                models = model_pack['models']
                weights = model_pack['weights']
                component_preds = []
                for model_key, feature_key in self._REGRESSION_COMPONENTS:
                    model = models[model_key]
                    X = self._select_model_features(model, features[feature_key])
                    component_preds.append(model.predict(X))
                # Explicit indexing: a pack with fewer than 3 weights fails this fold.
                final_pred = (weights[0] * component_preds[0]
                              + weights[1] * component_preds[1]
                              + weights[2] * component_preds[2])
                fold_predictions.append(final_pred[0])
            except Exception as e:
                self._log(f"Error in regression fold {fold_num}: {e}")

        if not fold_predictions:
            raise ValueError("No successful regression predictions from any fold")

        final_prediction = float(np.mean(fold_predictions))
        std_prediction = float(np.std(fold_predictions))
        return PredictionResult(
            patient_id=patient_id,
            task='regression',
            predictions={'mmse_score': final_prediction, 'std': std_prediction},
            confidence=1.0 / (1.0 + std_prediction),
            metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
        )

    def _run_task(self, patient: PatientData, task: str) -> PredictionResult:
        """Run one task for one patient, converting any failure into an error result.

        On failure the returned PredictionResult has ``predictions={'error': msg}``
        and ``metadata={'status': 'failed'}``; the error is also logged in verbose mode.
        """
        predictors = {
            'binary': self._predict_binary,
            'multiclass': self._predict_multiclass,
            'regression': self._predict_regression,
        }
        try:
            features = self._prepare_features(patient, task)
            if task not in predictors:
                raise ValueError(f"Unknown task: {task}")
            result = predictors[task](features)
            # The id derived from the features is a file stem; the caller's
            # patient_id is authoritative.
            result.patient_id = patient.patient_id
            return result
        except Exception as e:
            self._log(f"Error predicting {task} for patient {patient.patient_id}: {e}")
            return PredictionResult(
                patient_id=patient.patient_id,
                task=task,
                predictions={'error': str(e)},
                metadata={'status': 'failed'}
            )

    def predict(self, patient_data: Union[PatientData, List[PatientData]],
                task: Optional[str] = None) -> Union[
            PredictionResult, List[PredictionResult], Dict[str, PredictionResult]]:
        """
        Make predictions for one or more patients.

        Args:
            patient_data: Single PatientData or list of PatientData objects
            task: Specific task ('binary', 'multiclass', 'regression') or None for all

        Returns:
            Per patient, one PredictionResult when ``task`` is given, or a dict
            mapping task name to PredictionResult when ``task`` is None. A
            single PatientData input returns that one item; a list input
            returns a list of them.

        A task that fails does not raise. It yields a PredictionResult whose
        ``predictions`` is ``{'error': <message>}`` and whose
        ``metadata['status']`` is ``'failed'``.
        """
        single_patient = isinstance(patient_data, PatientData)
        patients = [patient_data] if single_patient else patient_data
        tasks_to_run = [task] if task else list(self._TASKS)

        results = []
        for patient in patients:
            patient_results = {current_task: self._run_task(patient, current_task)
                               for current_task in tasks_to_run}
            results.append(patient_results[task] if task else patient_results)

        return results[0] if single_patient else results