# Higgs_Codec_Extended / boson_codeit.py
# Batch-encodes a dataset of audio files into discrete codes with the Higgs audio tokenizer.
import os
import sys
import json
import warnings

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.nn.utils import remove_weight_norm
from tqdm import tqdm

from higgs_audio_tokenizer import HiggsAudioTokenizer

warnings.filterwarnings('ignore')
def remove_weight_norms_from_model(model):
    """Strip weight-norm reparameterization from every module that has it (inference only)."""
    for module in model.modules():
        try:
            remove_weight_norm(module)
        except ValueError:
            # Module has no weight norm applied; skip it.
            continue
    return model
class EncodedResult:
    """Minimal container for encoder output, exposing the `.audio_codes` attribute."""
    def __init__(self, audio_codes):
        self.audio_codes = audio_codes
def encode_batch(model, x_batch):
    """Encode a batch of waveforms (batch, 1, samples) into codes of shape (batch, n_q, frames)."""
    # Semantic branch: regression target from the frozen semantic teacher.
    e_semantic_input = model.get_regress_target(x_batch).detach()
    e_semantic = model.encoder_semantic(e_semantic_input.transpose(1, 2))
    # Acoustic branch: convolutional encoder over the raw waveform.
    e_acoustic = model.encoder(x_batch)
    # If the two branches disagree in frame count, re-encode a padded copy of the
    # waveform and crop both branches to a common length.
    if e_acoustic.shape[2] != e_semantic.shape[2]:
        pad_size = 160 * model.semantic_downsample_factor
        x_slice = x_batch[:, 0, :]
        x_padded = F.pad(x_slice, (pad_size, pad_size))
        e_acoustic = model.encoder(x_padded.unsqueeze(1))
    min_len = min(e_acoustic.shape[2], e_semantic.shape[2])
    e_acoustic = e_acoustic[:, :, :min_len]
    e_semantic = e_semantic[:, :, :min_len]
    # Fuse both streams and project before quantization.
    e = torch.cat([e_acoustic, e_semantic], dim=1)
    e = model.fc_prior(e.transpose(1, 2))
    if model.quantizer_type == "RVQ":
        e = e.transpose(1, 2)
        _, codes, _, _ = model.quantizer(e, model.frame_rate, None)
        codes = codes.permute(1, 0, 2)  # reorder to (batch, n_q, frames)
    else:
        _, codes = model.quantizer(e)
        codes = codes.permute(0, 2, 1)  # reorder to (batch, n_q, frames)
    return EncodedResult(audio_codes=codes)
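# Minimal usage sketch for encode_batch (illustrative only; assumes a loaded `model`
# on `device` and mono audio at the model's sample rate -- not a documented API):
#   wav = torch.randn(1, 1, 44100, device=device)   # (batch, channels, samples)
#   codes = encode_batch(model, wav).audio_codes    # (batch, n_q, code_frames)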
def fix_all_inference_issues(model):
    """Put the tokenizer and its semantic teacher fully into inference mode.

    Clears training flags, disables gradient checkpointing, and forces dropout
    into eval mode so batch encoding is deterministic and avoids recomputation.
    """
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        # model.eval() propagates recursively, but some wrapped submodules cache a
        # stale `training` flag, so clear it explicitly on every module.
        for module in model.modules():
            module.eval()
            module.training = False
        if hasattr(model, 'semantic_model'):
            print("Fixing semantic model...")
            model.semantic_model = model.semantic_model.to(device)
            model.semantic_model.eval()

            def disable_gradient_checkpointing(module):
                if hasattr(module, 'gradient_checkpointing'):
                    module.gradient_checkpointing = False
                if hasattr(module, 'gradient_checkpointing_disable'):
                    try:
                        module.gradient_checkpointing_disable()
                    except Exception:
                        pass
                for child in module.children():
                    disable_gradient_checkpointing(child)

            disable_gradient_checkpointing(model.semantic_model)
            if hasattr(model.semantic_model, 'encoder'):
                model.semantic_model.encoder.gradient_checkpointing = False
                if hasattr(model.semantic_model.encoder, 'layers'):
                    for layer in model.semantic_model.encoder.layers:
                        if hasattr(layer, 'gradient_checkpointing'):
                            layer.gradient_checkpointing = False

        def set_dropout_eval(module):
            if isinstance(module, nn.Dropout):
                module.eval()
                module.training = False
            for child in module.children():
                set_dropout_eval(child)

        set_dropout_eval(model)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return model
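# Optional sanity check (not part of the original pipeline): after the call above,
# no submodule should still report training mode.
#   assert not any(m.training for m in model.modules())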
def inference_pipeline(checkpoint_path, config_path, device='cuda'):
    """Instantiate HiggsAudioTokenizer from a JSON config and load a training checkpoint for inference."""
    print("Loading config...")
    with open(config_path, 'r') as f:
        config = json.load(f)
    print("Creating model...")
    model = HiggsAudioTokenizer(
        n_filters=config['n_filters'],
        D=config['D'],
        target_bandwidths=config['target_bandwidths'],
        ratios=config['ratios'],
        sample_rate=config['sample_rate'],
        bins=config['bins'],
        n_q=config['n_q'],
        codebook_dim=config.get('codebook_dim', None),
        semantic_techer=config['semantic_techer'],  # keyword spelled as in the upstream signature
        device=device
    ).to(device)
    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    # Strip the 'module.' prefix left by DataParallel/DDP training, if present.
    new_state_dict = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(new_state_dict, strict=False)
    print("Fixing inference issues...")
    model = fix_all_inference_issues(model)
    return model
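# Note (illustrative): load_state_dict(..., strict=False) silently ignores key
# mismatches. To inspect what was skipped, capture its return value, e.g.:
#   result = model.load_state_dict(new_state_dict, strict=False)
#   print(result.missing_keys, result.unexpected_keys)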
warnings.filterwarnings("ignore")
OUTPUT_DIR = "/home/ubuntu/data_boson_44.1khz"
BATCH_SIZE = 32
SAMPLE_RATE = 44100
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/Qanary_data"
print(f"Using device: {DEVICE}")
os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing")
from datasets import load_from_disk
print(f"Loading dataset from: {DATASET_PATH}")
ds = load_from_disk(DATASET_PATH)
print(f"Dataset info: {ds}")
columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask']
existing_columns = [col for col in columns_to_remove if col in ds.column_names]
if existing_columns:
ds = ds.remove_columns(existing_columns)
df = ds.to_pandas()
print(f"Loaded {len(df)} files from dataset")
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory '{OUTPUT_DIR}' is ready.")
print("Checking for already processed files...")
def get_output_path(audio_path):
base_name = Path(audio_path).stem
return os.path.join(OUTPUT_DIR, f"{base_name}.pt")
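# For example (hypothetical path): get_output_path("/data/audio/clip_001.wav")
# returns f"{OUTPUT_DIR}/clip_001.pt".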
original_count = len(df)
df['output_exists'] = df['filename'].apply(lambda x: os.path.exists(get_output_path(x)))
df_filtered = df[~df['output_exists']].copy()
skipped_count = original_count - len(df_filtered)
print(f"Found {skipped_count} already processed files. Skipping them.")
print(f"Processing {len(df_filtered)} remaining files.")
if len(df_filtered) == 0:
    print("All files have already been processed!")
    sys.exit(0)
print("Loading Higgs Audio Tokenizer model...")
from transformers import HubertModel
from higgs_audio_tokenizer import HiggsAudioTokenizer
checkpoint_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/outputs_CQT/checkpoints/step_99000.pth'
config_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/config copy.json'
model = inference_pipeline(checkpoint_path, config_path, DEVICE)
_ = model.eval()
model = remove_weight_norms_from_model(model)
print(f"Model loaded on {DEVICE}")
hop_length = model.hop_length
print(f"Encoder hop length: {hop_length}")
print(f"\nStarting batch processing with batch size {BATCH_SIZE}...")
filenames = df_filtered['filename'].tolist()
total_processed = 0
total_errors = 0
with torch.no_grad():
for batch_start in tqdm(range(0, len(filenames), BATCH_SIZE), desc="Processing batches"):
batch_end = min(batch_start + BATCH_SIZE, len(filenames))
batch_filenames = filenames[batch_start:batch_end]
batch_audio = []
batch_lengths = []
batch_outputs = []
for filename in batch_filenames:
output_path = get_output_path(filename)
if os.path.exists(output_path):
continue
try:
wav, _ = librosa.load(filename, sr=SAMPLE_RATE)
wav_tensor = torch.from_numpy(wav).float()
batch_audio.append(wav_tensor)
batch_lengths.append(len(wav))
batch_outputs.append(output_path)
except Exception as e:
print(f"\nError loading {filename}: {e}")
total_errors += 1
continue
if not batch_audio:
continue
        # Zero-pad every clip on the right to the length of the longest clip in the batch.
        max_len = max(len(x) for x in batch_audio)
        padded_batch = []
        for audio in batch_audio:
            pad_len = max_len - len(audio)
            if pad_len > 0:
                audio = F.pad(audio, (0, pad_len), mode='constant', value=0)
            padded_batch.append(audio)
        # Stack to (batch, samples), add a mono channel dim -> (batch, 1, samples), move to device.
        batch_tensor = torch.stack(padded_batch, dim=0)
        batch_tensor = batch_tensor.unsqueeze(1)
        batch_tensor = batch_tensor.to(DEVICE)
try:
encoded = encode_batch(model, batch_tensor)
codes = encoded.audio_codes
for idx, (output_path, orig_len) in enumerate(zip(batch_outputs, batch_lengths)):
true_code_len = int(np.ceil(orig_len / hop_length))
item_codes = codes[idx, :, :true_code_len].cpu()
torch.save(item_codes, output_path)
total_processed += 1
except Exception as e:
print(f"\nError encoding batch: {e}")
total_errors += len(batch_outputs)
print("\n" + "="*50)
print("PROCESSING COMPLETE!")
print("="*50)
print(f"Successfully processed: {total_processed} files")
print(f"Previously processed: {skipped_count} files")
print(f"Errors encountered: {total_errors} files")
print(f"Output directory: {OUTPUT_DIR}")
final_count = len(list(Path(OUTPUT_DIR).glob("*.pt")))
print(f"Total .pt files in output: {final_count}")