File size: 8,874 Bytes
410ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a65a2e6
d059b0e
410ae2c
 
 
 
 
 
 
 
 
d059b0e
 
410ae2c
 
a65a2e6
 
 
 
 
 
 
410ae2c
 
 
 
a65a2e6
410ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a65a2e6
410ae2c
 
 
 
 
 
 
 
 
 
 
 
 
a65a2e6
 
 
 
 
 
 
 
 
 
410ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a65a2e6
 
 
 
 
 
 
410ae2c
 
 
a65a2e6
410ae2c
 
 
 
 
 
a65a2e6
 
410ae2c
 
 
 
 
 
 
 
 
 
 
 
 
9b2188a
410ae2c
a65a2e6
410ae2c
 
 
a65a2e6
 
410ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a65a2e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""
CUPE: Easy usage with automatic downloading from Hugging Face Hub
"""

import torch
import torchaudio
from huggingface_hub import hf_hub_download
import importlib.util
import sys
import os

def load_cupe_model(model_name="english", device="auto"):
    """
    Load CUPE model with automatic downloading from Hugging Face Hub
    
    Args:
        model_name: "english", "multilingual-mls", or "multilingual-mswc"
        device: "auto", "cpu", or "cuda"
    
    Returns:
        Tuple of (extractor, windowing_module)
    """
    
    # Model checkpoint mapping
    model_files = {
        "english": "en_libri1000_uj01d_e199_val_GER=0.2307.ckpt",
        "multilingual-mls": "multi_MLS8_uh02_e36_val_GER=0.2334.ckpt",
        "multilingual-mswc": "multi_mswc38_ug20_e59_val_GER=0.5611.ckpt"
    }
    
    if model_name not in model_files:
        raise ValueError(f"Model {model_name} not available. Choose from: {list(model_files.keys())}")
    
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    
    print(f"Loading CUPE {model_name} model...")
    
    # Download model files from Hugging Face Hub
    repo_id = "Tabahi/CUPE-2i"
    
    model_file = hf_hub_download(repo_id=repo_id, filename="model2i.py")
    windowing_file = hf_hub_download(repo_id=repo_id, filename="windowing.py")
    mapper_file = hf_hub_download(repo_id=repo_id, filename="mapper.py")
    model_utils_file = hf_hub_download(repo_id=repo_id, filename="model_utils.py")
    checkpoint_file = hf_hub_download(repo_id=repo_id, filename=f"ckpt/{model_files[model_name]}")
    
    # Dynamically import the modules
    def import_module_from_file(module_name, file_path):
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module

    _ = import_module_from_file("model_utils", model_utils_file)
    model2i = import_module_from_file("model2i", model_file)
    windowing = import_module_from_file("windowing", windowing_file)
    mapper = import_module_from_file("mapper", mapper_file)

    phoneme_to_token = mapper.phoneme_mapped_index
    token_to_phoneme = {v: k for k, v in phoneme_to_token.items()}
    group_to_token = mapper.phoneme_groups_index
    token_to_group = {v: k for k, v in group_to_token.items()}

    # Initialize the model
    extractor = model2i.CUPEEmbeddingsExtractor(checkpoint_file, device=device)
    
    print(f"Model loaded on {device}")
    return extractor, windowing, token_to_phoneme, token_to_group

def predict_phonemes(audio_path, model_name="english", device="auto"):
    """
    Predict phonemes from audio file
    
    Args:
        audio_path: Path to audio file
        model_name: CUPE model variant to use
        device: Device to run inference on
    
    Returns:
        Dictionary with predictions and metadata
    """
    
    # Load model
    extractor, windowing, token_to_phoneme, token_to_group = load_cupe_model(model_name, device)
    
    # Audio processing parameters
    sample_rate = 16000
    window_size_ms = 120
    stride_ms = 80
    
    # Load and preprocess audio
    audio, orig_sr = torchaudio.load(audio_path)
    
    # Convert to mono if stereo
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)
    
    resampler = torchaudio.transforms.Resample(
            sample_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )
    
    # Always use resampler for consistency
    audio = resampler(audio)
    
    # Move to device and add batch dimension
    audio = audio.to(device)
    audio_batch = audio.unsqueeze(0)
    
    print(f"Processing audio: {audio.shape[1]/sample_rate:.2f}s duration")
    
    # Window the audio
    windowed_audio = windowing.slice_windows(
        audio_batch,
        sample_rate,
        window_size_ms,
        stride_ms
    )
    
    batch_size, num_windows, window_size = windowed_audio.shape
    windows_flat = windowed_audio.reshape(-1, window_size)
    
    # Get model predictions
    logits_phonemes, logits_groups = extractor.predict(
        windows_flat, 
        return_embeddings=False, 
        groups_only=False
    )
    
    # Reshape and stitch predictions
    frames_per_window = logits_phonemes.shape[1]
    
    logits_phonemes = logits_phonemes.reshape(batch_size, num_windows, frames_per_window, -1)
    logits_groups = logits_groups.reshape(batch_size, num_windows, frames_per_window, -1)
    
    phoneme_logits = windowing.stich_window_predictions(
        logits_phonemes,
        original_audio_length=audio_batch.size(2),
        cnn_output_size=frames_per_window,
        sample_rate=sample_rate,
        window_size_ms=window_size_ms,
        stride_ms=stride_ms
    )
    
    group_logits = windowing.stich_window_predictions(
        logits_groups,
        original_audio_length=audio_batch.size(2),
        cnn_output_size=frames_per_window,
        sample_rate=sample_rate,
        window_size_ms=window_size_ms,
        stride_ms=stride_ms
    )
    
    # Convert to probabilities and predictions
    phoneme_probs = torch.softmax(phoneme_logits.squeeze(0), dim=-1)
    group_probs = torch.softmax(group_logits.squeeze(0), dim=-1)
    
    phoneme_preds = torch.argmax(phoneme_probs, dim=-1)
    group_preds = torch.argmax(group_probs, dim=-1)
    
    phonemes_sequence = [token_to_phoneme[int(p)] for p in phoneme_preds.cpu().numpy()]
    groups_sequence = [token_to_group[int(g)] for g in group_preds.cpu().numpy()]
    # remove noise
    phonemes_sequence = [p for p in phonemes_sequence if p != 'noise']
    groups_sequence = [g for g in groups_sequence if g != 'noise']
    
    
    # Calculate timestamps (approximately 16ms per frame)
    num_frames = phoneme_probs.shape[0]
    
    print(f"Processed {num_frames} frames ({num_frames*16}ms total)")
    
    return {
        'phoneme_probabilities': phoneme_probs.cpu().numpy(),
        'phoneme_predictions': phoneme_preds.cpu().numpy(),
        'group_probabilities': group_probs.cpu().numpy(), 
        'group_predictions': group_preds.cpu().numpy(),
        'phonemes_sequence': phonemes_sequence,
        'groups_sequence': groups_sequence,
        'model_info': {
            'model_name': model_name,
            'sample_rate': sample_rate,
            'frames_per_second': 1000/16,  # ~62.5 fps
            'num_phoneme_classes': phoneme_probs.shape[-1],
            'num_group_classes': group_probs.shape[-1]
        }
    }

# Example usage
if __name__ == "__main__":
    
    # Simple example
    audio_file = "samples/109867__timkahn__butterfly.wav.wav"  # Replace with your audio file
    
    
    if not os.path.exists(audio_file):
        print(f"Audio file {audio_file} does not exist. Please provide a valid path.")
        sys.exit(1)
    
    torch.manual_seed(42)  # For reproducibility
    # Predict with English model
    results = predict_phonemes(
        audio_path=audio_file,
        model_name="english",  # or "multilingual-mls" or "multilingual-mswc"
        device="cpu"
    )
    
    print(f"\nResults:")
    print(f"Phoneme predictions shape: {results['phoneme_predictions'].shape}")
    print(f"Group predictions shape: {results['group_predictions'].shape}")
    print(f"Model info: {results['model_info']}")
    
    # Show first 10 predictions with timestamps
    print(f"\nFirst 10 frame predictions:")
    for i in range(min(10, len(results['phoneme_predictions']))):
        print(f"Frame {i}: phoneme={results['phoneme_predictions'][i]}, "
              f"group={results['group_predictions'][i]}")

    print(f"\nPhonemes sequence: {results['phonemes_sequence'][:10]}...")  # Show first 10 phonemes
    print(f"Groups sequence: {results['groups_sequence'][:10]}...")

''' output:
Loading CUPE english model...
Model loaded on cpu
Processing audio: 1.26s duration
Processed 75 frames (1200ms total)

Results:
Phoneme predictions shape: (75,)
Group predictions shape: (75,)
Model info: {'model_name': 'english', 'sample_rate': 16000, 'frames_per_second': 62.5, 'num_phoneme_classes': 67, 'num_group_classes': 17}

First 10 frame predictions:
Frame 0: phoneme=66, group=16
Frame 1: phoneme=66, group=16
Frame 2: phoneme=29, group=7
Frame 3: phoneme=66, group=16
Frame 4: phoneme=66, group=16
Frame 5: phoneme=66, group=16
Frame 6: phoneme=10, group=2
Frame 7: phoneme=66, group=16
Frame 8: phoneme=66, group=16
Frame 9: phoneme=66, group=16

Phonemes sequence: ['b', 'ʌ', 't', 'h', 'ʌ', 'f', 'l', 'æ']...
Groups sequence: ['voiced_stops', 'central_vowels', 'voiceless_stops', 'voiceless_fricatives', 'central_vowels', 'voiceless_fricatives', 'laterals', 'low_vowels']...
'''