# Spaces: Sleeping
"""
Emotion Analysis Framework

A framework for analyzing emotions from patient audio recordings using wav2vec2 models.

Author: Marek Sviderski

This framework supports three main tasks:
- Binary classification: Distinguishing between Alzheimer's Disease (AD) and Healthy Control (HC)
- Multiclass classification: Classifying into HC, Mild Cognitive Impairment (MCI), and AD
- Regression: Predicting the Mini-Mental State Examination (MMSE) score

This code is designed to be modular and extensible, allowing for easy integration of new models and strategies.
It uses dataclasses for structured data representation and provides methods for feature extraction, model loading, and predictions.
"""
import os
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import joblib
import numpy as np
import pandas as pd
import torch

from models.inference_wav2vec import Wav2VecInference
from preprocessing.flattening_categorical import CategoricalFlattening
from preprocessing.flattening_minirocket import MiniRocketFlattening
from preprocessing.flattening_statistical import StatisticalFlattening
from utils.config import ProjectConfig
from utils.config_types import ChunkLength
warnings.filterwarnings('ignore', category=UserWarning, module='xgboost') | |
warnings.filterwarnings('ignore', message='Some weights of the model checkpoint') | |
warnings.filterwarnings('ignore', message='Some weights of Wav2Vec2ForSequenceClassification') | |
@dataclass
class PatientData:
    """Data class for patient information.

    Attributes:
        patient_id: Identifier reported back on every prediction result.
        audio_path: Path to the patient's audio recording.
        demographics: Optional extra features; when non-empty they are
            appended as columns to the flattened audio features.
    """
    patient_id: str
    audio_path: str
    # default_factory gives every instance its own dict (no shared mutable default).
    demographics: Dict[str, Any] = field(default_factory=dict)
@dataclass
class PredictionResult:
    """Data class for prediction results.

    Attributes:
        patient_id: Identifier of the patient the result belongs to.
        task: Task name ('binary', 'multiclass' or 'regression').
        predictions: Task-specific payload. Classification results carry
            'class' and 'label'; regression results carry 'mmse_score' and
            'std'; failed predictions carry a single 'error' message.
        probabilities: Per-label probabilities for classification tasks.
        confidence: Scalar confidence, or None when not available.
        metadata: Free-form extra information (e.g. number of folds used).
    """
    patient_id: str
    task: str
    predictions: Dict[str, Any]
    probabilities: Optional[Dict[str, Any]] = None
    confidence: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def summary(self) -> str:
        """Return a clean one-line summary of the prediction.

        Failed results (those whose ``predictions`` hold an 'error' entry)
        and results without a confidence value are summarised instead of
        raising while formatting.
        """
        # Failed results have no 'label' / 'mmse_score' keys to format.
        if 'error' in self.predictions:
            return f"{self.task} failed: {self.predictions['error']}"

        if self.task in ('binary', 'multiclass'):
            headline = f"{self.task.capitalize()}: {self.predictions['label']}"
            if self.confidence is None:
                return headline
            return f"{headline} (confidence: {self.confidence:.2f})"

        if self.task == 'regression':
            return (f"MMSE Score: {self.predictions['mmse_score']:.1f} "
                    f"± {self.predictions['std']:.1f}")

        return str(self.predictions)
class EmotionAnalysisFramework:
    """
    End-to-end framework for emotion analysis from patient recordings.

    This framework provides three prediction tasks:
    - Binary classification: AD vs Healthy Control (HC)
    - Multiclass classification: HC vs MCI vs AD
    - Regression: MMSE score prediction

    Args:
        config_path: Optional path to custom configuration
        model_dir: Path to directory containing model weights
        verbose: Whether to print detailed progress information
    """

    # Supported tasks; each one has a sub-directory of the same name in model_dir.
    TASKS = ('binary', 'multiclass', 'regression')

    # (model key inside a regression fold pack, prepared-feature key), in the
    # same order as the fusion weights stored with the pack.
    _REGRESSION_BRANCHES = (
        ("ridge_1_5_minirocket", 'chunk_1_5'),
        ("ridge_3_5_minirocket", 'chunk_3_5'),
        ("ridge_4_5_categorical", 'chunk_4_5_demo'),
    )

    def __init__(self, config_path: Optional[str] = None,
                 model_dir: Optional[str] = None,
                 verbose: bool = False):
        self.config = ProjectConfig(config_path) if config_path else ProjectConfig()
        self.model_dir = model_dir
        self.verbose = verbose
        # The wav2vec model is created lazily on the first feature extraction.
        self.wav2vec_model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # models[task][model_type][fold_num] -> object loaded from the fold file
        self.models = {task: {} for task in self.TASKS}
        self.strategies = {}
        self._initialize_strategies()
        self._load_models()

    def _log(self, message: str):
        """Print message only if verbose mode is enabled"""
        if self.verbose:
            print(message)

    def _initialize_strategies(self):
        """Initialize all flattening strategies"""
        self.strategies = {
            'statistical': StatisticalFlattening(),
            'categorical': CategoricalFlattening(),
            'minirocket': MiniRocketFlattening()
        }

    def _load_models(self):
        """Load all trained models.

        Falls back to a ``model_weights`` directory next to this file when no
        ``model_dir`` was given.

        Raises:
            ValueError: If the model directory does not exist, or if no fold
                model could be loaded for one of the tasks.
        """
        if not self.model_dir:
            # Try to find models in package directory
            package_dir = os.path.dirname(os.path.abspath(__file__))
            self.model_dir = os.path.join(package_dir, "model_weights")

        if not os.path.exists(self.model_dir):
            raise ValueError(f"Model directory not found: {self.model_dir}")

        for task in self.TASKS:
            self._load_task_models(task, os.path.join(self.model_dir, task))

        # Verify that at least one fold was loaded per task. Checking the
        # task dict itself is not enough: it holds an (empty) model-type
        # entry as soon as the task directory exists.
        for task in self.TASKS:
            if not any(self.models[task].values()):
                raise ValueError(f"No {task} models found in {self.model_dir}")

    def _load_task_models(self, task: str, path: str):
        """Load every ``model_fold_*.joblib`` file for a specific task.

        Files that fail to load are reported through ``_log`` and skipped.
        """
        if not os.path.exists(path):
            self._log(f"Warning: {task} model path not found: {path}")
            return

        model_type = 'simple' if task in ['binary', 'regression'] else 'fusion'
        self.models[task][model_type] = {}

        # Sorted so the fold order does not depend on the file system.
        model_files = sorted(
            f for f in os.listdir(path)
            if f.startswith('model_fold_') and f.endswith('.joblib')
        )
        for file in model_files:
            fold_num = file.split('_')[-1].replace('.joblib', '')
            model_path = os.path.join(path, file)
            try:
                # NOTE: joblib.load unpickles the file; only load trusted weights.
                self.models[task][model_type][fold_num] = joblib.load(model_path)
                self._log(f"Loaded {task} {model_type} model fold {fold_num}")
            except Exception as e:
                self._log(f"Error loading model {model_path}: {e}")

    def _extract_wav2vec_features(self, audio_path: str,
                                  chunk_length: "ChunkLength") -> pd.DataFrame:
        """Extract wav2vec emotion features from an audio file.

        Returns a DataFrame with one row per analysed segment: 'filename'
        (stem of the audio path), 'start', 'end', and one column per entry
        of the segment's emotion mapping.
        """
        if self.wav2vec_model is None:
            self.wav2vec_model = Wav2VecInference(self.config, verbose=self.verbose)
            self.wav2vec_model.load_model()

        chunk_config = self.config.get_chunk_params(chunk_length)
        emotions_over_time = self.wav2vec_model.analyze_emotions_over_time(
            audio_path,
            segment_duration=chunk_config.segment_duration,
            overlap_duration=chunk_config.overlap_duration
        )

        filename = str(Path(audio_path).stem)
        rows = [
            {'filename': filename, 'start': start, 'end': end, **emotions}
            for start, end, emotions in emotions_over_time
        ]
        return pd.DataFrame(rows)

    @staticmethod
    def _attach_demographics(flattened: pd.DataFrame,
                             demographics: Dict[str, Any]) -> pd.DataFrame:
        """Append demographic columns to every row of ``flattened``.

        The demographics frame reuses the index of ``flattened`` so the
        column-wise concat cannot misalign rows or introduce NaN rows.
        """
        if not demographics:
            return flattened
        demo_df = pd.DataFrame([demographics] * len(flattened), index=flattened.index)
        return pd.concat([flattened, demo_df], axis=1)

    @staticmethod
    def _align_features(model: Any, X: pd.DataFrame) -> pd.DataFrame:
        """Select, in training order, the columns the model was fitted on.

        Models without ``feature_names_in_`` receive ``X`` unchanged. Names
        missing from ``X`` are skipped, so a mismatch surfaces as an error
        from the model's own predict call.
        """
        if not hasattr(model, 'feature_names_in_'):
            return X
        known = [name for name in model.feature_names_in_ if name in X.columns]
        return X[known]

    @staticmethod
    def _patient_id_from(frame: pd.DataFrame) -> Any:
        """Return the first 'filename' value of ``frame`` or 'unknown'."""
        if 'filename' in frame.columns:
            return frame['filename'].iloc[0]
        return 'unknown'

    @staticmethod
    def _majority_vote(fold_predictions, mean_probabilities) -> int:
        """Return the class index predicted by most folds.

        Averaging class indices is not a vote (votes for 0 and 2 would
        average to 1), so the votes are counted per class. Ties are broken
        in favour of the tied class with the highest mean probability.
        """
        votes = np.bincount(np.asarray(fold_predictions, dtype=int))
        tied = np.flatnonzero(votes == votes.max())
        if len(tied) == 1:
            return int(tied[0])
        probs = np.asarray(mean_probabilities, dtype=float)
        return int(tied[np.argmax(probs[tied])])

    def _prepare_features(self, patient_data: "PatientData",
                          task: str) -> Dict[str, pd.DataFrame]:
        """Prepare the flattened feature frames a specific task expects.

        Returns a dict keyed by 'simple' (binary), 'chunk1'/'chunk2'
        (multiclass) or 'chunk_1_5'/'chunk_3_5'/'chunk_4_5_demo'
        (regression). Unknown tasks yield an empty dict.
        """
        audio_path = patient_data.audio_path
        demographics = patient_data.demographics
        prepared_data = {}

        if task == 'binary':
            # Binary uses statistical flattening with 3.5s chunks
            df_3_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_3_5)
            flattened = self.strategies['statistical'].flatten_dataframe(df_3_5)
            prepared_data['simple'] = self._attach_demographics(flattened, demographics)

        elif task == 'multiclass':
            # Multiclass uses categorical flattening with different chunk lengths
            df_1_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_1_5)
            df_4_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_4_5)
            categorical = self.strategies['categorical']
            prepared_data['chunk1'] = self._attach_demographics(
                categorical.flatten_dataframe(df_1_5), demographics)
            prepared_data['chunk2'] = self._attach_demographics(
                categorical.flatten_dataframe(df_4_5), demographics)

        elif task == 'regression':
            # Regression uses minirocket for 1.5s and 3.5s, categorical for 4.5s
            df_1_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_1_5)
            df_3_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_3_5)
            df_4_5 = self._extract_wav2vec_features(audio_path, ChunkLength.LENGTH_4_5)
            minirocket = self.strategies['minirocket']
            prepared_data['chunk_1_5'] = minirocket.flatten_dataframe(df_1_5)
            prepared_data['chunk_3_5'] = minirocket.flatten_dataframe(df_3_5)
            # Demographics are added to the 4.5s categorical features only.
            prepared_data['chunk_4_5_demo'] = self._attach_demographics(
                self.strategies['categorical'].flatten_dataframe(df_4_5), demographics)

        return prepared_data

    def _predict_binary(self, features: Dict[str, pd.DataFrame]) -> "PredictionResult":
        """Make binary (HC vs AD) predictions by voting over all fold models.

        Raises:
            ValueError: If no binary model is loaded or every fold failed.
        """
        model_type = 'simple'
        fold_models = self.models['binary'].get(model_type)
        if not fold_models:
            raise ValueError("No binary models loaded")

        patient_id = self._patient_id_from(features[model_type])

        fold_predictions = []
        fold_probabilities = []
        for fold_num, model in fold_models.items():
            try:
                X = self._align_features(model, features[model_type])
                label = model.predict(X)[0]
                proba = model.predict_proba(X)[0]
            except Exception as e:
                self._log(f"Error in fold {fold_num}: {e}")
                continue
            # Appended together so both lists always describe the same folds.
            fold_predictions.append(label)
            fold_probabilities.append(proba)

        if not fold_predictions:
            raise ValueError("No successful predictions from any fold")

        mean_probabilities = np.mean(fold_probabilities, axis=0)
        final_prediction = self._majority_vote(fold_predictions, mean_probabilities)

        return PredictionResult(
            patient_id=patient_id,
            task='binary',
            predictions={'class': final_prediction, 'label': 'AD' if final_prediction else 'HC'},
            probabilities={'HC': float(mean_probabilities[0]), 'AD': float(mean_probabilities[1])},
            confidence=float(np.max(mean_probabilities)),
            metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
        )

    def _predict_multiclass(self, features: Dict[str, pd.DataFrame]) -> "PredictionResult":
        """Make multiclass (HC / MCI / AD) predictions.

        Each fold fuses the probabilities of its two chunk models with the
        fold's weights (equal weights when none are stored); the final class
        is decided by a vote over the folds.

        Raises:
            ValueError: If no fusion model is loaded or every fold failed.
        """
        fold_packs = self.models['multiclass'].get('fusion')
        if not fold_packs:
            raise ValueError("No multiclass fusion model loaded")

        class_labels = ['HC', 'MCI', 'AD']
        patient_id = self._patient_id_from(features['chunk1'])

        fold_predictions = []
        fold_probabilities = []
        for fold_num, model_pack in fold_packs.items():
            try:
                model1 = model_pack['chunk1']
                model2 = model_pack['chunk2']
                pred_proba_1 = model1.predict_proba(
                    self._align_features(model1, features['chunk1']))
                pred_proba_2 = model2.predict_proba(
                    self._align_features(model2, features['chunk2']))

                # Apply fusion weights
                weights = model_pack.get('weights', [0.5, 0.5])
                fusion_proba = weights[0] * pred_proba_1 + weights[1] * pred_proba_2
                label = np.argmax(fusion_proba, axis=1)[0]
                proba = fusion_proba[0]
            except Exception as e:
                self._log(f"Error in multiclass fold {fold_num}: {e}")
                continue
            fold_predictions.append(label)
            fold_probabilities.append(proba)

        if not fold_predictions:
            raise ValueError("No successful multiclass predictions from any fold")

        mean_probabilities = np.mean(fold_probabilities, axis=0)
        final_prediction = self._majority_vote(fold_predictions, mean_probabilities)
        prob_dict = {label: float(prob) for label, prob in zip(class_labels, mean_probabilities)}

        return PredictionResult(
            patient_id=patient_id,
            task='multiclass',
            predictions={'class': final_prediction, 'label': class_labels[final_prediction]},
            probabilities=prob_dict,
            confidence=float(np.max(mean_probabilities)),
            metadata={'num_folds': len(fold_predictions)}
        )

    def _predict_regression(self, features: Dict[str, pd.DataFrame]) -> "PredictionResult":
        """Predict the MMSE score as the mean over all fold ensembles.

        Each fold combines three ridge models with the fold's weights. The
        reported 'std' is the spread of the fold predictions and the
        confidence is ``1 / (1 + std)``.

        Raises:
            ValueError: If no regression model is loaded or every fold failed.
        """
        model_type = 'simple'
        fold_packs = self.models['regression'].get(model_type)
        if not fold_packs:
            raise ValueError("No regression models loaded")

        patient_id = self._patient_id_from(features['chunk_1_5'])

        fold_predictions = []
        for fold_num, model_pack in fold_packs.items():
            try:
                models = model_pack['models']
                weights = model_pack['weights']

                # Weighted sum of the three branch predictions.
                final_pred = 0
                for position, (model_key, feature_key) in enumerate(self._REGRESSION_BRANCHES):
                    model = models[model_key]
                    X = self._align_features(model, features[feature_key])
                    final_pred = final_pred + weights[position] * model.predict(X)

                fold_predictions.append(final_pred[0])
            except Exception as e:
                self._log(f"Error in regression fold {fold_num}: {e}")
                continue

        if not fold_predictions:
            raise ValueError("No successful regression predictions from any fold")

        final_prediction = float(np.mean(fold_predictions))
        std_prediction = float(np.std(fold_predictions))

        return PredictionResult(
            patient_id=patient_id,
            task='regression',
            predictions={'mmse_score': final_prediction, 'std': std_prediction},
            confidence=1.0 / (1.0 + std_prediction),
            metadata={'model_type': model_type, 'num_folds': len(fold_predictions)}
        )

    def predict(self, patient_data: Union["PatientData", List["PatientData"]],
                task: Optional[str] = None) -> Union[
            "PredictionResult", List["PredictionResult"], Dict[str, "PredictionResult"]]:
        """
        Make predictions for one or more patients.

        A failure in one task does not abort the run: the task's entry
        becomes a PredictionResult whose predictions hold the 'error'
        message and whose metadata has status 'failed'.

        Args:
            patient_data: Single PatientData or list of PatientData objects
            task: Specific task ('binary', 'multiclass', 'regression') or None for all

        Returns:
            Per patient either one PredictionResult (when ``task`` is given)
            or a dict of results by task; a single patient yields that item
            directly, a list of patients yields a list of such items.
        """
        single_patient = isinstance(patient_data, PatientData)
        if single_patient:
            patient_data = [patient_data]

        tasks_to_run = [task] if task else list(self.TASKS)
        predictors = {
            'binary': self._predict_binary,
            'multiclass': self._predict_multiclass,
            'regression': self._predict_regression,
        }

        results = []
        for patient in patient_data:
            patient_results = {}
            for current_task in tasks_to_run:
                try:
                    features = self._prepare_features(patient, current_task)
                    if current_task not in predictors:
                        raise ValueError(f"Unknown task: {current_task}")
                    result = predictors[current_task](features)
                    # Report the caller's id rather than the audio file stem.
                    result.patient_id = patient.patient_id
                    patient_results[current_task] = result
                except Exception as e:
                    self._log(f"Error predicting {current_task} for patient {patient.patient_id}: {e}")
                    patient_results[current_task] = PredictionResult(
                        patient_id=patient.patient_id,
                        task=current_task,
                        predictions={'error': str(e)},
                        metadata={'status': 'failed'}
                    )

            # Single task -> the result itself; all tasks -> dict keyed by task.
            results.append(patient_results[task] if task else patient_results)

        return results[0] if single_patient else results