|
|
|
|
|
|
|
import torch |
|
from datasets import load_dataset |
|
from transformers import AutoFeatureExtractor |
|
import os |
|
from tqdm import tqdm |
|
import matplotlib.pyplot as plt |
|
|
|
from pipeline_utils import compute_speaker_stats, plot_reconstruction |
|
|
|
|
|
|
|
|
|
|
|
def _average_metrics(metrics_list):
    """Return the per-key arithmetic mean of a list of metric dicts.

    The keys are taken from the first dict, and every dict is expected to
    contain them. An empty list yields an empty dict instead of raising.
    """
    if not metrics_list:
        return {}
    count = len(metrics_list)
    return {
        metric: sum(m[metric] for m in metrics_list) / count
        for metric in metrics_list[0]
    }


def _save_reconstruction_plots(results, out_dir="plots"):
    """Write one reconstruction plot per row of *results* into *out_dir*.

    Each row is passed to ``plot_reconstruction`` together with its index.
    The current pyplot figure is then saved as
    ``reconstruction_sample{i}.png`` and closed, so figures do not pile up
    across iterations.
    """
    # Create the output directory once rather than on every iteration.
    os.makedirs(out_dir, exist_ok=True)
    for i, row in enumerate(tqdm(results)):
        # NOTE(review): the return value of plot_reconstruction is not used.
        # This assumes it draws on the current pyplot figure; confirm
        # against pipeline_utils.
        plot_reconstruction(row, i)
        plt.savefig(os.path.join(out_dir, f"reconstruction_sample{i}.png"))
        plt.close()


def main():
    """Extract prosody embeddings for the VoxPopuli dummy validation split.

    Steps:
      1. Load the dataset and run the preprocessor's ``extract_features`` on
         each row's audio array. Save the result to ``processed_dataset``.
      2. Compute speaker statistics from the processed rows and save them to
         ``speaker_stats.pt``.
      3. Run the ``prosody-embedding`` pipeline over every processed row and
         save its output (only the pipeline's columns) to
         ``embeddings_dataset``.
      4. Print the codebook shape, one example code and the averaged metrics.
      5. Write one reconstruction plot per sample under ``plots/``.
    """
    dataset = load_dataset(
        "sanchit-gandhi/voxpopuli_dummy",
        split="validation",
    )

    # NOTE(review): this checkpoint appears to ship a custom feature extractor
    # and pipeline task. Recent transformers versions need
    # trust_remote_code=True (or an interactive confirmation) to load remote
    # code. Confirm how this is meant to be run before adding it.
    preprocessor = AutoFeatureExtractor.from_pretrained(
        'MU-NLPC/F0_Energy_joint_VQVAE_embeddings-preprocessor',
    )

    processed_dataset = dataset.map(
        lambda x: preprocessor.extract_features(x['audio']['array']),
        load_from_cache_file=False,
    )
    processed_dataset.save_to_disk("processed_dataset")

    # compute_speaker_stats is opaque here. Its result is passed unchanged to
    # the pipeline below via speaker_stats=.
    speaker_stats = compute_speaker_stats(processed_dataset)
    torch.save(speaker_stats, "speaker_stats.pt")

    from transformers import pipeline

    embedding_pipeline = pipeline(
        task="prosody-embedding",
        model="MU-NLPC/F0_Energy_joint_VQVAE_embeddings",
        f0_interp=False,
        f0_normalize=True,
        speaker_stats=speaker_stats,
    )

    # Drop the input columns so the saved dataset holds only pipeline output.
    results = processed_dataset.map(
        lambda x: embedding_pipeline(x),
        remove_columns=processed_dataset.column_names,
        load_from_cache_file=False,
    )
    results.save_to_disk("embeddings_dataset")

    num_samples = len(results)
    print(f"Processed {num_samples} samples")

    # The attribute path to the first VQ level's codebook is specific to this
    # model's implementation.
    embedding_codebook = embedding_pipeline.model.vq.level_blocks[0].k
    print("embedding_codebook.shape", embedding_codebook.shape)

    if num_samples == 0:
        # Everything below indexes results[0]; bail out instead of crashing.
        print("No samples were processed; skipping metrics and plots.")
        return

    # The exact nesting of 'codes' depends on the pipeline's output format.
    embeddings_example = results[0]['codes'][0][0]
    print("Embeddings example:", embeddings_example)

    avg_metrics = _average_metrics([row['metrics'] for row in results])
    print("\nAverage metrics across dataset:")
    print(avg_metrics)

    print("Plotting reconstruction curves...")
    _save_reconstruction_plots(results, "plots")
    print("Done.")


if __name__ == '__main__':
    main()
|
|
|
|
|
|