# F0_Energy_joint_VQVAE_embeddings-interp / prosody_embedding_pipeline.py
# (Hugging Face page residue, kept as comments so the module parses:)
# Daporte — rename prosodic_embedding_pipeline.py to prosody_embedding_pipeline.py
# commit 70aa22e (verified)
from transformers import Pipeline
import numpy as np
import torch
from typing import Dict, Union, List, Optional
from pathlib import Path
import logging
from datasets import Dataset
logger = logging.getLogger(__name__)
class ProsodyEmbeddingPipeline(Pipeline):
    """Pipeline that embeds prosodic features (F0 and intensity) with a VQ-VAE model.

    ``preprocess`` z-normalizes per-speaker features (optionally), ``_forward``
    stacks them into channels and runs the model, and ``postprocess`` reverts
    normalization and computes reconstruction metrics.
    """

    def __init__(
        self,
        speaker_stats,
        f0_interp,
        f0_normalize,
        stats_dir: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            speaker_stats: mapping from speaker id to an object exposing
                ``f0_mean``/``f0_std``/``intensity_mean``/``intensity_std``.
            f0_interp: if True, use the interpolated F0 track (``'f0_interp'``
                key) and feed the model 2 channels; otherwise use raw ``'f0'``
                plus an unvoiced zero-mask channel (3 channels total).
            f0_normalize: if True, z-normalize F0 and intensity per speaker.
            stats_dir: optional directory holding speaker statistics.
        """
        super().__init__(**kwargs)
        self.stats_dir = Path(stats_dir) if stats_dir else None
        self.speaker_stats = speaker_stats
        self.f0_interp = f0_interp
        self.f0_normalize = f0_normalize

    def _sanitize_parameters(self, **kwargs):
        """Split pipeline kwargs into (preprocess, forward, postprocess) kwargs."""
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {
            "return_tensors": kwargs.pop("return_tensors", "pt")
        }
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs: Union[str, Dict, Dataset]) -> Dict:
        """Normalize F0/intensity for one utterance; keep originals for metrics."""
        spkr_id = inputs['speaker_id']
        stats = self.speaker_stats[spkr_id]
        f0 = torch.Tensor(inputs['f0_interp'] if self.f0_interp else inputs['f0'])
        f0_orig = f0.clone()  # original F0 before normalization
        intensity = torch.Tensor(inputs['intensity'])
        intensity_orig = intensity.clone()  # original intensity before normalization
        if self.f0_normalize:
            # Normalize voiced frames only: f0 == 0 marks unvoiced frames.
            voiced = f0 != 0
            if stats.f0_std != 0:
                f0[voiced] = (f0[voiced] - stats.f0_mean) / stats.f0_std
            nonzero = intensity != 0
            if stats.intensity_std != 0:
                intensity[nonzero] = (intensity[nonzero] - stats.intensity_mean) / stats.intensity_std
        zero_mask = None
        if not self.f0_interp:
            # BUGFIX: the mask must exist whenever f0_interp is False, because
            # _forward unconditionally builds a third channel from it in that
            # configuration. The old code stored None when f0_normalize was
            # False, and _forward then crashed on torch.Tensor(None).
            zero_mask = (f0 == 0) * 1.0
        return {
            'f0': f0,
            'intensity': intensity,
            'zero_mask': zero_mask,
            'f0_mean': stats.f0_mean,
            'f0_std': stats.f0_std,
            'intensity_mean': stats.intensity_mean,
            'intensity_std': stats.intensity_std,
            'f0_orig': f0_orig,  # original features before normalization
            'intensity_orig': intensity_orig,
            'speaker_id': spkr_id,
        }

    def _forward(self, features: Dict) -> Dict:
        """Run the model on the preprocessed features."""
        self.model.eval()
        f0 = torch.Tensor(features['f0'])
        intensity = torch.Tensor(features['intensity'])
        zero_mask = None
        if self.f0_interp:
            stacked_features = torch.stack([f0, intensity], dim=0).to(self.model.device)
        else:
            zero_mask = torch.Tensor(features['zero_mask'])
            stacked_features = torch.stack([f0, intensity, zero_mask], dim=0).to(self.model.device)
        stacked_features = stacked_features.unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            model_outputs = self.model(features=stacked_features)
        return {
            **model_outputs,
            'input_features': {
                'zero_mask': zero_mask,
                'f0_orig': features['f0_orig'],
                'f0_mean': features['f0_mean'],
                'f0_std': features['f0_std'],
                'intensity_mean': features['intensity_mean'],
                'intensity_std': features['intensity_std'],
                'intensity_orig': features['intensity_orig'],
            }
        }

    def postprocess(self, outputs: Dict, return_tensors: str = "pt") -> Dict:
        """Denormalize reconstructions and compute F0/intensity metrics.

        Adds ``f0_recon``, ``intensity_recon`` and ``metrics`` (keys
        ``f0_large_diff_percent`` and ``intensity_rmse``) to ``outputs``.
        """
        epsilon = 1e-10
        DIFF_THRESHOLD = 0.2  # relative F0 error above which a frame counts as "large"

        input_f0 = outputs['input_features']['f0_orig']
        # (channels, time) for the single batch item; channel 0 = F0,
        # channel 1 = intensity, channel 2 (if present) = unvoiced mask.
        f0_recon = outputs['f0'][0, :, :]

        # Revert per-speaker z-normalization on the reconstructed channels.
        if self.f0_normalize:
            f0_recon[0] = f0_recon[0] * outputs['input_features']["f0_std"] + outputs['input_features']["f0_mean"]
            if not self.f0_interp:
                # BUGFIX: use ones_like/zeros_like so the mask lives on the
                # same device as f0_recon (the old torch.tensor([1.0]) was
                # CPU-only and failed for GPU tensors).
                mask = torch.where(
                    f0_recon[2] < 0.5,
                    torch.ones_like(f0_recon[2]),
                    torch.zeros_like(f0_recon[2]),
                )
                f0_recon[0] = f0_recon[0] * mask
            f0_recon[1] = f0_recon[1] * outputs['input_features']["intensity_std"] + outputs['input_features']["intensity_mean"]

        # --- F0 metric: percentage of frames with relative error > threshold ---
        input_f0_np = input_f0.cpu().numpy()
        output_f0_np = f0_recon[0].cpu().numpy()
        # Truncate both tracks to a multiple of 16 frames
        # (presumably the model's chunk size — TODO confirm).
        truncated_length = (len(input_f0_np) // 16) * 16
        input_f0_np = input_f0_np[:truncated_length]
        output_f0_np = output_f0_np[:truncated_length]
        total_points = len(input_f0_np)
        # BUGFIX: utterances shorter than 16 frames truncate to length 0;
        # guard the division instead of raising ZeroDivisionError.
        if total_points > 0:
            input_f0_safe = np.where(input_f0_np == 0, epsilon, input_f0_np)
            rel_diff = np.abs(input_f0_np - output_f0_np) / np.abs(input_f0_safe)
            f0_large_diff_percent = float(np.sum(rel_diff > DIFF_THRESHOLD)) / total_points * 100.0
        else:
            f0_large_diff_percent = 0.0

        # --- Intensity metric: RMSE on the truncated track ---
        input_intensity_np = outputs['input_features']['intensity_orig'].cpu().numpy()
        output_intensity_np = f0_recon[1].cpu().numpy()
        truncated_length = (len(input_intensity_np) // 16) * 16
        input_intensity_np = input_intensity_np[:truncated_length]
        output_intensity_np = output_intensity_np[:truncated_length]
        if len(input_intensity_np) > 0:
            intensity_rmse = float(np.sqrt(np.mean((input_intensity_np - output_intensity_np) ** 2)))
        else:
            intensity_rmse = 0.0  # avoid np.mean of an empty array (nan + warning)

        outputs['f0_recon'] = output_f0_np
        outputs['intensity_recon'] = output_intensity_np
        outputs['metrics'] = {
            'f0_large_diff_percent': f0_large_diff_percent,
            'intensity_rmse': intensity_rmse,
        }
        # BUGFIX: the old print() misused an f-string with no placeholder;
        # report through the module logger instead.
        logger.info("metrics: %s", outputs['metrics'])

        if return_tensors == "np":
            outputs = {
                k: v.cpu().numpy() if torch.is_tensor(v) else v
                for k, v in outputs.items()
            }
        return outputs