import dataclasses
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import parselmouth
import amfm_decompy.basic_tools as basic
import amfm_decompy.pYAAPT as pYAAPT
import scipy.interpolate as scipy_interp
from scipy.signal import medfilt
from transformers import PretrainedConfig, FeatureExtractionMixin
from datasets import Dataset

@dataclass
class SpeakerStats:
    f0_mean: float
    f0_std: float
    intensity_mean: float
    intensity_std: float

    @classmethod
    def from_features(cls, f0_values: List[np.ndarray], intensity_values: List[np.ndarray]):
        """Compute speaker-level statistics from lists of per-utterance
        F0 and intensity tracks."""
        f0_arrays = [np.array(f0) for f0 in f0_values]
        intensity_arrays = [np.array(i) for i in intensity_values]

        # Exclude unvoiced frames (F0 == 0) so the F0 statistics describe
        # voiced speech only.
        f0_concat = np.concatenate([f0[f0 != 0] for f0 in f0_arrays])
        intensity_concat = np.concatenate(intensity_arrays)

        return cls(
            f0_mean=float(np.mean(f0_concat)),
            f0_std=float(np.std(f0_concat)),
            intensity_mean=float(np.mean(intensity_concat)),
            intensity_std=float(np.std(intensity_concat))
        )
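

# A minimal usage sketch (not part of the original pipeline): stats like these
# are typically used for per-speaker z-scoring. `normalize_f0` is a
# hypothetical helper; leaving unvoiced frames at zero is an assumption.
def normalize_f0(f0: np.ndarray, stats: SpeakerStats) -> np.ndarray:
    voiced = f0 != 0
    out = np.zeros_like(f0, dtype=np.float64)
    # Z-score voiced frames only; the epsilon guards against a zero std.
    out[voiced] = (f0[voiced] - stats.f0_mean) / (stats.f0_std + 1e-8)
    return out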


class ProsodyConfig(PretrainedConfig):
    """Configuration class for prosody preprocessing."""

    model_type = "prosody_preprocessor"

    def __init__(
        self,
        sampling_rate: int = 16000,
        frame_length: float = 20.0,  # YAAPT analysis window, in milliseconds
        frame_space: float = 5.0,    # hop between frames, in milliseconds
        torch_dtype: str = "float32",
        **kwargs
    ):
        super().__init__(torch_dtype=torch_dtype, **kwargs)
        self.sampling_rate = sampling_rate
        self.frame_length = frame_length
        self.frame_space = frame_space
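
# Usage sketch: as a PretrainedConfig subclass, ProsodyConfig round-trips
# through the standard transformers save/load API; the directory name is
# illustrative.
#
#   config = ProsodyConfig(sampling_rate=16000, frame_space=5.0)
#   config.save_pretrained("prosody_preprocessor")
#   config = ProsodyConfig.from_pretrained("prosody_preprocessor")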


class ProsodyPreprocessor(FeatureExtractionMixin):
    config_class = ProsodyConfig

    def __init__(self,
                 sampling_rate: int = 16000,
                 frame_length: float = 20.0,
                 frame_space: float = 5.0,
                 torch_dtype: str = "float32",
                 config: Optional[ProsodyConfig] = None,
                 **kwargs):
        super().__init__()
        self.config = config
        self.speaker_stats: Dict[str, SpeakerStats] = {}
        # When a config is supplied, it takes precedence over the keyword
        # defaults, so both construction paths stay consistent.
        if config is not None:
            sampling_rate = config.sampling_rate
            frame_length = config.frame_length
            frame_space = config.frame_space
        self.sampling_rate = sampling_rate
        self.frame_length = frame_length
        self.frame_space = frame_space

    def extract_features(self, audio):
        """Extract frame-level F0 and intensity features from a waveform."""
        # Patch YAAPT's PitchObj with the interpolation routine defined at the
        # bottom of this file (pchip interpolation over unvoiced gaps).
        pYAAPT.PitchObj.interpolate = interpolate

        audio = torch.Tensor(audio)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        f0, f0_interp = self._get_f0(audio)
        f0 = f0[0, 0, :]
        f0_interpolated = f0_interp[0, 0, :]

        # Drop the leading frames so the F0 track lines up with the intensity
        # contour computed below.
        f0 = f0[6:]
        f0_interpolated = f0_interpolated[6:]

        sound = parselmouth.Sound(audio.numpy(), sampling_frequency=self.sampling_rate, start_time=0)

        # A 1/200 s time step (5 ms) matches the default YAAPT frame_space.
        intensity = sound.to_intensity(time_step=1 / 200.0)
        intensity_values = intensity.values.T.flatten()

        # F0 and intensity are computed independently, so crop both tracks
        # to the shorter length.
        min_len = min(len(f0), len(intensity_values))
        f0 = f0[:min_len]
        f0_interpolated = f0_interpolated[:min_len]
        intensity_values = intensity_values[:min_len]

        # Floor the intensity at 20 dB to suppress spurious values in silence.
        intensity_values[intensity_values < 20] = 20

        return {
            "f0": f0,
            "f0_interp": f0_interpolated,
            "intensity": intensity_values,
        }
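
    # Usage sketch (illustrative values): the returned tracks share one frame
    # rate and one length.
    #
    #   pre = ProsodyPreprocessor()
    #   feats = pre.extract_features(waveform)  # 1-D array at 16 kHz
    #   assert len(feats["f0"]) == len(feats["intensity"])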

    def collect_stats(self, dataset: Dataset, num_proc: int = 4, batch_size: int = 32) -> Tuple[Dataset, Dict[str, SpeakerStats]]:
        """First pass: extract features with dataset.map, then aggregate
        per-speaker statistics. Returns the feature dataset and the stats."""

        def extract_features_batch(examples):
            features_list = []
            for audio in examples['audio']:
                features = self.extract_features(audio)
                features_list.append(features)

            return {
                'f0': [f['f0'] for f in features_list],
                'intensity': [f['intensity'] for f in features_list],
                'speaker_id': examples['speaker_id']
            }

        features_dataset = dataset.map(
            extract_features_batch,
            batched=True,
            batch_size=batch_size,
            num_proc=num_proc,
            remove_columns=dataset.column_names
        )

        # Group the extracted tracks by speaker before computing statistics.
        speaker_features = {}
        for item in features_dataset:
            speaker_id = item['speaker_id']
            if speaker_id not in speaker_features:
                speaker_features[speaker_id] = {'f0': [], 'intensity': []}
            speaker_features[speaker_id]['f0'].append(item['f0'])
            speaker_features[speaker_id]['intensity'].append(item['intensity'])

        self.speaker_stats = {
            spk: SpeakerStats.from_features(feats['f0'], feats['intensity'])
            for spk, feats in speaker_features.items()
        }

        return features_dataset, self.speaker_stats
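
    # Usage sketch for the two-pass flow; the 'audio' and 'speaker_id' columns
    # are assumed to exist, matching extract_features_batch above.
    #
    #   features_ds, stats = pre.collect_stats(ds, num_proc=1)
    #   pre.save_stats("speaker_stats.pt")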

    def save_stats(self, path: str):
        """Save speaker stats to a file."""
        stats_dict = {
            spk: dataclasses.asdict(stats)
            for spk, stats in self.speaker_stats.items()
        }
        torch.save(stats_dict, path)

    @classmethod
    def load_stats(cls, path: str) -> Dict[str, SpeakerStats]:
        """Load speaker stats from a file."""
        stats_dict = torch.load(path)
        return {
            spk: SpeakerStats(**stats)
            for spk, stats in stats_dict.items()
        }

    def _get_f0(self, audio: torch.Tensor):
        """Extract F0 using YAAPT."""
        # Pad by half an analysis window on each side so that frame centers
        # cover the full utterance.
        to_pad = int(self.frame_length / 1000 * self.sampling_rate) // 2

        f0s = []
        f0s_interp = []
        for y in audio.numpy().astype(np.float64):
            y_pad = np.pad(y.squeeze(), (to_pad, to_pad), "constant", constant_values=0)
            signal = basic.SignalObj(y_pad, self.sampling_rate)
            pitch = pYAAPT.yaapt(
                signal,
                frame_length=self.frame_length,
                frame_space=self.frame_space,
                nccf_thresh1=0.25,
                tda_frame_length=25.0
            )
            f0s_interp.append(pitch.samp_interp[None, None, :])
            f0s.append(pitch.samp_values[None, None, :])

        f0 = np.vstack(f0s)
        f0_interp = np.vstack(f0s_interp)

        # Treat implausible estimates (above 500 Hz or negative) as unvoiced.
        f0[f0 > 500] = 0
        f0_interp[f0_interp > 500] = 0
        f0[f0 < 0] = 0
        f0_interp[f0_interp < 0] = 0

        return f0, f0_interp


def interpolate(self):
    """Replacement for pYAAPT.PitchObj.interpolate: fill unvoiced gaps in the
    F0 track by pchip interpolation over median-filtered voiced frames. It is
    monkey-patched onto PitchObj in ProsodyPreprocessor.extract_features."""
    pitch = np.zeros(self.nframes)
    pitch[:] = self.samp_values
    # Median-filter a copy of the track; its nonzero frames anchor the
    # interpolation below.
    pitch2 = medfilt(self.samp_values, self.SMOOTH_FACTOR)

    edges = self.edges_finder(pitch)
    first_sample = pitch[0]
    last_sample = pitch[-1]

    if len(np.nonzero(pitch2)[0]) < 2:
        # Too few voiced frames to interpolate; fall back to the typical pitch.
        pitch[pitch == 0] = self.PTCH_TYP
    else:
        nz_pitch = pitch2[pitch2 > 0]
        pitch2 = scipy_interp.pchip(np.nonzero(pitch2)[0],
                                    nz_pitch)(range(self.nframes))
        pitch[pitch == 0] = pitch2[pitch == 0]
    if self.SMOOTH > 0:
        pitch = medfilt(pitch, self.SMOOTH_FACTOR)
    try:
        # Extend the first/last voiced value over unvoiced stretches at the
        # track boundaries.
        if first_sample == 0:
            if edges[0] == 0:
                edges[0] = 1
            pitch[:edges[0] - 1] = pitch[edges[0]]
        if last_sample == 0:
            pitch[edges[-1] + 1:] = pitch[edges[-1]]
    except IndexError:
        # edges can be empty when the track has no voiced/unvoiced boundary.
        pass
    self.samp_interp = pitch
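

if __name__ == "__main__":
    # Smoke test on synthetic noise; a rough sketch for when no real dataset
    # is at hand. Noise is mostly unvoiced, so the F0 track should be near 0.
    rng = np.random.default_rng(0)
    wav = rng.standard_normal(16000)  # one second at 16 kHz
    pre = ProsodyPreprocessor(sampling_rate=16000)
    feats = pre.extract_features(wav)
    print({k: np.asarray(v).shape for k, v in feats.items()})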