# generate_embeddings.py
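"""Generate prosody (F0/energy) VQ-VAE embeddings for an audio dataset.

Saves the preprocessed dataset, per-speaker statistics, and the resulting
embedding dataset to disk, then reports average reconstruction metrics and
writes per-sample reconstruction plots. Assumes a local pipeline_utils
module providing compute_speaker_stats and plot_reconstruction.
"""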
import os

import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, pipeline

from pipeline_utils import compute_speaker_stats, plot_reconstruction

def main():
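    # Load the evaluation audio dataset; alternative corpora are left below as
    # commented-out options.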
    dataset = load_dataset(
        "sanchit-gandhi/voxpopuli_dummy",
        # "train",
        split="validation",
    )
    # dataset = load_dataset(
    #     "mythicinfinity/libritts",
    #     "clean",
    #     split="test.clean",
    #     # trust_remote_code=True,
    # )
    # dataset = load_dataset(
    #     "facebook/voxpopuli",
    #     "en",
    #     split="test",
    # )
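
    # Load the matching feature extractor for the VQ-VAE model from the Hub.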
    preprocessor = AutoFeatureExtractor.from_pretrained(
        'MU-NLPC/F0_Energy_joint_VQVAE_embeddings-preprocessor',
        # trust_remote_code=True,
    )
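
    # Extract input features for every example and save the processed dataset to disk.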
    processed_dataset = dataset.map(
        lambda x: preprocessor.extract_features(x['audio']['array']),
        load_from_cache_file=False,
        # num_proc=4,
    )
    processed_dataset.save_to_disk("processed_dataset")
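
    # Compute per-speaker statistics used by the pipeline for F0 normalization
    # (f0_normalize=True below).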
    speaker_stats = compute_speaker_stats(processed_dataset)
    torch.save(speaker_stats, "speaker_stats.pt")
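
    # Build the prosody-embedding pipeline around the pretrained VQ-VAE model.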
    embedding_pipeline = pipeline(
        task="prosody-embedding",
        model="MU-NLPC/F0_Energy_joint_VQVAE_embeddings",
        f0_interp=False,
        f0_normalize=True,
        speaker_stats=speaker_stats,
        # trust_remote_code=True,
    )
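
    # Run the pipeline over the whole dataset, keeping only its outputs.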
    results = processed_dataset.map(
        lambda x: embedding_pipeline(x),
        remove_columns=processed_dataset.column_names,
        load_from_cache_file=False,
        # num_proc=4,
    )
    results.save_to_disk("embeddings_dataset")
    print(f"Processed {len(results)} samples")
    embedding_codebook = embedding_pipeline.model.vq.level_blocks[0].k
    print("embedding_codebook.shape:", embedding_codebook.shape)

    embeddings_example = results[0]['codes'][0][0]
    print("Embeddings example:", embeddings_example)
    # Inspect the embeddings in the codebook as follows:
    # code_point = embeddings_example[0]
    # print("code_point", code_point)
    # code_point_embedding = embedding_codebook[code_point]
    # print("code_point_embedding", code_point_embedding)
    # print("code_point_embedding.shape", code_point_embedding.shape)
    # Check that they are the same as the hidden states used in the model:
    # hidden_states = np.array(results[0]['hidden_states'])
    # hidden_state = hidden_states[0, 0, :, 0]
    # print("hidden_state", hidden_state)
    metrics_list = [result['metrics'] for result in results]
    avg_metrics = {}
    for metric in results[0]['metrics'].keys():
        values = [m[metric] for m in metrics_list]
        avg_metrics[metric] = sum(values) / len(values)
        # print("metric", metric)
        # print("len(values)", len(values))

    print("\nAverage metrics across dataset:")
    print(avg_metrics)
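
    # Save one reconstruction plot per sample under plots/.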
print(f"Plotting reconstruction curves...")
for i in tqdm(range(len(results))):
fig = plot_reconstruction(results[i], i)
os.makedirs('plots', exist_ok=True)
plt.savefig(f'plots/reconstruction_sample{i}.png')
plt.close()
print(f"Done.")
if __name__ == '__main__':
main()