|
|
|
""" |
|
Script para testar classificação de sotaques em uma pasta de áudios. |
|
Este script carrega o modelo treinado e classifica todos os arquivos de áudio em uma pasta. |
|
""" |
|
|
|
import os |
|
import sys |
|
import argparse |
|
import glob |
|
import torch |
|
import librosa |
|
import numpy as np |
|
from pathlib import Path |
|
from transformers import ( |
|
AutoFeatureExtractor, |
|
AutoModelForAudioClassification |
|
) |
|
import pandas as pd |
|
from collections import Counter |
|
import seaborn as sns |
|
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score |
|
|
|
def load_model(model_path): |
|
""" |
|
Carrega o modelo treinado e o feature extractor. |
|
""" |
|
print(f"Carregando modelo de: {model_path}") |
|
|
|
try: |
|
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path) |
|
model = AutoModelForAudioClassification.from_pretrained(model_path) |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
model.to(device) |
|
model.eval() |
|
|
|
print(f"Modelo carregado com sucesso! Usando device: {device}") |
|
return model, feature_extractor, device |
|
|
|
except Exception as e: |
|
print(f"Erro ao carregar modelo: {e}") |
|
return None, None, None |
|
|
|
def load_audio(file_path, target_sampling_rate=16000): |
|
""" |
|
Carrega um arquivo de áudio e redimensiona para a taxa de amostragem alvo. |
|
""" |
|
try: |
|
|
|
audio, sr = librosa.load(file_path, sr=target_sampling_rate, mono=True) |
|
return audio, sr |
|
except Exception as e: |
|
print(f"Erro ao carregar áudio {file_path}: {e}") |
|
return None, None |
|
|
|
def predict_audio(model, feature_extractor, device, audio_path): |
|
""" |
|
Classifica um único arquivo de áudio usando janela deslizante. |
|
""" |
|
|
|
audio, sr = load_audio(audio_path, feature_extractor.sampling_rate) |
|
if audio is None: |
|
return None, None, None, None |
|
|
|
try: |
|
|
|
window_size = int(sr * 5.0) |
|
overlap = int(sr * 2.5) |
|
|
|
|
|
if len(audio) <= window_size: |
|
predicted_label, confidence, class_id = predict_segment( |
|
model, feature_extractor, device, audio, sr |
|
) |
|
return predicted_label, confidence, class_id, 1 |
|
|
|
|
|
predictions_list = [] |
|
confidences_list = [] |
|
|
|
start = 0 |
|
segments_processed = 0 |
|
|
|
while start < len(audio): |
|
end = min(start + window_size, len(audio)) |
|
segment = audio[start:end] |
|
|
|
|
|
if len(segment) >= sr: |
|
pred_label, confidence, class_id = predict_segment( |
|
model, feature_extractor, device, segment, sr |
|
) |
|
if pred_label is not None: |
|
predictions_list.append(class_id) |
|
confidences_list.append(confidence) |
|
segments_processed += 1 |
|
|
|
|
|
start += window_size - overlap |
|
|
|
|
|
if end == len(audio): |
|
break |
|
|
|
if not predictions_list: |
|
return None, None, None, 0 |
|
|
|
|
|
predicted_label, final_confidence, predicted_class_id = combine_predictions( |
|
predictions_list, confidences_list |
|
) |
|
|
|
return predicted_label, final_confidence, predicted_class_id, segments_processed |
|
|
|
except Exception as e: |
|
print(f"Erro ao processar {audio_path}: {e}") |
|
return None, None, None, 0 |
|
|
|
def predict_segment(model, feature_extractor, device, audio_segment, sr): |
|
""" |
|
Classifica um segmento individual de áudio. |
|
""" |
|
try: |
|
|
|
inputs = feature_extractor( |
|
audio_segment, |
|
sampling_rate=sr, |
|
max_length=int(sr * 5.0), |
|
truncation=True, |
|
padding=True, |
|
return_tensors="pt" |
|
) |
|
|
|
|
|
inputs = {k: v.to(device) for k, v in inputs.items()} |
|
|
|
|
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) |
|
predicted_class_id = predictions.argmax().item() |
|
confidence = predictions.max().item() |
|
|
|
|
|
label_map = {0: "pt_br", 1: "pt_pt"} |
|
predicted_label = label_map.get(predicted_class_id, "unknown") |
|
|
|
return predicted_label, confidence, predicted_class_id |
|
|
|
except Exception as e: |
|
return None, None, None |
|
|
|
def combine_predictions(predictions_list, confidences_list): |
|
""" |
|
Combina múltiplas predições usando voto majoritário ponderado pela confiança. |
|
""" |
|
|
|
predictions = np.array(predictions_list) |
|
confidences = np.array(confidences_list) |
|
|
|
|
|
class_scores = {} |
|
for class_id in [0, 1]: |
|
mask = predictions == class_id |
|
if np.any(mask): |
|
|
|
class_scores[class_id] = np.sum(confidences[mask]) |
|
else: |
|
class_scores[class_id] = 0.0 |
|
|
|
|
|
predicted_class_id = max(class_scores.keys(), key=lambda k: class_scores[k]) |
|
|
|
|
|
winner_mask = predictions == predicted_class_id |
|
if np.any(winner_mask): |
|
final_confidence = np.mean(confidences[winner_mask]) |
|
else: |
|
final_confidence = 0.0 |
|
|
|
|
|
label_map = {0: "pt_br", 1: "pt_pt"} |
|
predicted_label = label_map.get(predicted_class_id, "unknown") |
|
|
|
return predicted_label, final_confidence, predicted_class_id |
|
|
|
def test_folder(model_path, audio_folder, output_file=None, supported_formats=None): |
|
""" |
|
Testa todos os áudios em uma pasta usando janela deslizante. |
|
""" |
|
if supported_formats is None: |
|
supported_formats = ['.wav', '.mp3', '.flac', '.m4a', '.ogg'] |
|
|
|
print(f"Testando áudios na pasta: {audio_folder}") |
|
print("Usando janela deslizante de 5s com sobreposição de 2.5s para áudios longos") |
|
|
|
|
|
model, feature_extractor, device = load_model(model_path) |
|
if model is None: |
|
return |
|
|
|
|
|
audio_files = [] |
|
for ext in supported_formats: |
|
pattern = os.path.join(audio_folder, f"**/*{ext}") |
|
audio_files.extend(glob.glob(pattern, recursive=True)) |
|
|
|
if not audio_files: |
|
print(f"Nenhum arquivo de áudio encontrado na pasta {audio_folder}") |
|
print(f"Formatos suportados: {supported_formats}") |
|
return |
|
|
|
print(f"Encontrados {len(audio_files)} arquivos de áudio") |
|
|
|
|
|
results = [] |
|
total_segments = 0 |
|
|
|
|
|
for i, audio_file in enumerate(audio_files, 1): |
|
perc = (i / len(audio_files)) * 100 |
|
print(f"Processando {i}/{len(audio_files)} ({perc:.2f}%): {os.path.basename(audio_file)}") |
|
|
|
|
|
predicted_label, confidence, class_id, segments_used = predict_audio( |
|
model, feature_extractor, device, audio_file |
|
) |
|
|
|
if predicted_label is not None: |
|
|
|
try: |
|
audio_duration = librosa.get_duration(filename=audio_file) |
|
except: |
|
audio_duration = None |
|
|
|
result = { |
|
'arquivo': os.path.basename(audio_file), |
|
'caminho_completo': audio_file, |
|
'predição': predicted_label, |
|
'confiança': confidence, |
|
'classe_id': class_id, |
|
'duração_segundos': audio_duration, |
|
'segmentos_analisados': segments_used |
|
} |
|
results.append(result) |
|
total_segments += segments_used |
|
|
|
|
|
if segments_used > 1: |
|
print(f" -> {predicted_label} (confiança: {confidence:.3f}) [{segments_used} segmentos]") |
|
else: |
|
print(f" -> {predicted_label} (confiança: {confidence:.3f})") |
|
else: |
|
print(f" -> Erro ao processar arquivo") |
|
|
|
|
|
if results: |
|
df = pd.DataFrame(results) |
|
|
|
|
|
print(f"\n=== Resumo dos Resultados ===") |
|
print(f"Total de arquivos processados: {len(results)}") |
|
print(f"Arquivos com erro: {len(audio_files) - len(results)}") |
|
print(f"Total de segmentos analisados: {total_segments}") |
|
|
|
|
|
if 'segmentos_analisados' in df.columns: |
|
avg_segments = df['segmentos_analisados'].mean() |
|
max_segments = df['segmentos_analisados'].max() |
|
multi_segment_files = len(df[df['segmentos_analisados'] > 1]) |
|
print(f"Segmentos por arquivo (média): {avg_segments:.1f}") |
|
print(f"Máximo de segmentos: {max_segments}") |
|
print(f"Arquivos com múltiplos segmentos: {multi_segment_files}") |
|
|
|
|
|
if 'duração_segundos' in df.columns and df['duração_segundos'].notna().any(): |
|
avg_duration = df['duração_segundos'].mean() |
|
max_duration = df['duração_segundos'].max() |
|
print(f"Duração média dos áudios: {avg_duration:.1f}s") |
|
print(f"Duração máxima: {max_duration:.1f}s") |
|
|
|
|
|
print(f"\nDistribuição das predições:") |
|
distribution = Counter(df['predição']) |
|
for label, count in distribution.items(): |
|
percentage = (count / len(results)) * 100 |
|
print(f" {label}: {count} arquivos ({percentage:.1f}%)") |
|
|
|
|
|
avg_confidence = df['confiança'].mean() |
|
print(f"\nConfiança média: {avg_confidence:.3f}") |
|
|
|
|
|
low_confidence = df[df['confiança'] < 0.7] |
|
if not low_confidence.empty: |
|
print(f"\nArquivos com baixa confiança (< 0.7): {len(low_confidence)}") |
|
for _, row in low_confidence.iterrows(): |
|
segments_info = f" [{row.get('segmentos_analisados', 1)} seg]" if row.get('segmentos_analisados', 1) > 1 else "" |
|
print(f" {row['arquivo']}: {row['predição']} ({row['confiança']:.3f}){segments_info}") |
|
|
|
|
|
if output_file: |
|
df.to_csv(output_file, index=False) |
|
print(f"\nResultados salvos em: {output_file}") |
|
|
|
|
|
if results is not None and len([r for r in results if 'pt_br' in r.get('caminho_completo', '') or 'pt_pt' in r.get('caminho_completo', '')]) > 0: |
|
metrics_file = output_file.replace('.csv', '_metrics.txt') |
|
with open(metrics_file, 'w', encoding='utf-8') as f: |
|
f.write("=== MÉTRICAS DE CLASSIFICAÇÃO ===\n\n") |
|
f.write(f"Total de arquivos processados: {len(results)}\n") |
|
f.write(f"Confiança média: {avg_confidence:.3f}\n") |
|
f.write(f"Total de segmentos analisados: {total_segments}\n\n") |
|
f.write("Distribuição das predições:\n") |
|
for label, count in distribution.items(): |
|
percentage = (count / len(results)) * 100 |
|
f.write(f" {label}: {count} arquivos ({percentage:.1f}%)\n") |
|
print(f"Métricas básicas salvas em: {metrics_file}") |
|
|
|
return df |
|
else: |
|
print("Nenhum arquivo foi processado com sucesso.") |
|
return None |
|
|
|
def main(): |
|
parser = argparse.ArgumentParser( |
|
description="Testa classificação de sotaques em uma pasta de áudios" |
|
) |
|
parser.add_argument( |
|
"audio_folder", |
|
help="Pasta contendo os arquivos de áudio para teste" |
|
) |
|
parser.add_argument( |
|
"--model_path", |
|
default="./nn/results/final_model", |
|
help="Caminho para o modelo treinado (default: ./nn/results/final_model)" |
|
) |
|
parser.add_argument( |
|
"--output", |
|
help="Arquivo CSV para salvar os resultados (opcional)" |
|
) |
|
parser.add_argument( |
|
"--formats", |
|
nargs="+", |
|
default=['.wav', '.mp3', '.flac', '.m4a', '.ogg'], |
|
help="Formatos de áudio suportados (default: .wav .mp3 .flac .m4a .ogg)" |
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if not os.path.exists(args.audio_folder): |
|
print(f"Erro: Pasta '{args.audio_folder}' não encontrada!") |
|
sys.exit(1) |
|
|
|
|
|
if not os.path.exists(args.model_path): |
|
print(f"Erro: Modelo '{args.model_path}' não encontrado!") |
|
sys.exit(1) |
|
|
|
|
|
results = test_folder( |
|
model_path=args.model_path, |
|
audio_folder=args.audio_folder, |
|
output_file=args.output, |
|
supported_formats=args.formats |
|
) |
|
|
|
|
|
if results is not None and not results.empty: |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
true_labels = [] |
|
pred_labels = results['predição'].tolist() |
|
|
|
for idx, row in results.iterrows(): |
|
arquivo = row['arquivo'] |
|
caminho_completo = row['caminho_completo'] |
|
|
|
|
|
if '/pt_br/' in caminho_completo or caminho_completo.endswith('/pt_br') or '\\pt_br\\' in caminho_completo or caminho_completo.endswith('\\pt_br'): |
|
true_labels.append('pt_br') |
|
elif '/pt_pt/' in caminho_completo or caminho_completo.endswith('/pt_pt') or '\\pt_pt\\' in caminho_completo or caminho_completo.endswith('\\pt_pt'): |
|
true_labels.append('pt_pt') |
|
|
|
elif 'pt_br' in arquivo.lower() or 'brasil' in arquivo.lower(): |
|
true_labels.append('pt_br') |
|
elif 'pt_pt' in arquivo.lower() or 'portugal' in arquivo.lower(): |
|
true_labels.append('pt_pt') |
|
else: |
|
|
|
true_labels.append('unknown') |
|
|
|
|
|
known_mask = [label != 'unknown' for label in true_labels] |
|
known_true = [true_labels[i] for i in range(len(true_labels)) if known_mask[i]] |
|
known_pred = [pred_labels[i] for i in range(len(pred_labels)) if known_mask[i]] |
|
|
|
if len(known_true) > 0: |
|
labels = ['pt_br', 'pt_pt'] |
|
cm = confusion_matrix(known_true, known_pred, labels=labels) |
|
|
|
|
|
unknown_count = len(true_labels) - len(known_true) |
|
accuracy = accuracy_score(known_true, known_pred) |
|
f1_macro = f1_score(known_true, known_pred, labels=labels, average='macro') |
|
f1_weighted = f1_score(known_true, known_pred, labels=labels, average='weighted') |
|
f1_pt_br = f1_score(known_true, known_pred, labels=labels, pos_label='pt_br', average='binary') if 'pt_br' in labels else 0 |
|
f1_pt_pt = f1_score(known_true, known_pred, labels=labels, pos_label='pt_pt', average='binary') if 'pt_pt' in labels else 0 |
|
|
|
print(f"\nEstatísticas da Matriz de Confusão:") |
|
print(f"Arquivos com labels conhecidos: {len(known_true)}") |
|
print(f"Arquivos com labels desconhecidos: {unknown_count}") |
|
print(f"Acurácia: {accuracy:.3f} ({accuracy*100:.1f}%)") |
|
print(f"F1-Score Macro: {f1_macro:.3f} ({f1_macro*100:.1f}%)") |
|
print(f"F1-Score Ponderado: {f1_weighted:.3f} ({f1_weighted*100:.1f}%)") |
|
print(f"F1-Score PT-BR: {f1_pt_br:.3f} ({f1_pt_br*100:.1f}%)") |
|
print(f"F1-Score PT-PT: {f1_pt_pt:.3f} ({f1_pt_pt*100:.1f}%)") |
|
|
|
|
|
print(f"\nRelatório de Classificação:") |
|
print(classification_report(known_true, known_pred, labels=labels, zero_division=0)) |
|
|
|
if unknown_count > 0: |
|
print(f"\nArquivos ignorados (sem label inferível):") |
|
for i, (true_label, arquivo) in enumerate(zip(true_labels, results['arquivo'])): |
|
if true_label == 'unknown': |
|
print(f" {arquivo}") |
|
|
|
|
|
errors = [] |
|
for i, (true_label, pred_label, arquivo) in enumerate(zip(known_true, known_pred, |
|
[results.iloc[j]['arquivo'] for j in range(len(results)) if known_mask[j]])): |
|
if true_label != pred_label: |
|
confidence = [results.iloc[j]['confiança'] for j in range(len(results)) if known_mask[j]][i] |
|
errors.append({ |
|
'arquivo': arquivo, |
|
'verdadeiro': true_label, |
|
'predito': pred_label, |
|
'confianca': confidence |
|
}) |
|
|
|
if errors: |
|
print(f"\nErros de Classificação ({len(errors)} arquivos):") |
|
for error in errors: |
|
print(f" {error['arquivo']}: {error['verdadeiro']} → {error['predito']} (conf: {error['confianca']:.3f})") |
|
else: |
|
print(f"\n✓ Nenhum erro de classificação encontrado!") |
|
|
|
|
|
plt.figure(figsize=(8, 6)) |
|
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', |
|
xticklabels=labels, yticklabels=labels) |
|
plt.title(f'Matriz de Confusão - Classificação de Sotaques\n({len(known_true)} arquivos, Acc: {accuracy:.1%}, F1-Macro: {f1_macro:.1%})') |
|
plt.xlabel('Predição') |
|
plt.ylabel('Verdadeiro') |
|
plt.tight_layout() |
|
|
|
|
|
confusion_matrix_path = args.output.replace('.csv', '_confusion_matrix.png') if args.output else 'confusion_matrix.png' |
|
plt.savefig(confusion_matrix_path, dpi=300, bbox_inches='tight') |
|
print(f"\nMatriz de confusão salva em: {confusion_matrix_path}") |
|
|
|
|
|
if args.output: |
|
detailed_metrics_file = args.output.replace('.csv', '_detailed_metrics.txt') |
|
with open(detailed_metrics_file, 'w', encoding='utf-8') as f: |
|
f.write("=== MÉTRICAS DETALHADAS DE CLASSIFICAÇÃO ===\n\n") |
|
f.write(f"Total de arquivos processados: {len(results)}\n") |
|
f.write(f"Arquivos com labels conhecidos: {len(known_true)}\n") |
|
f.write(f"Arquivos com labels desconhecidos: {unknown_count}\n") |
|
f.write(f"Total de segmentos analisados: {total_segments}\n\n") |
|
|
|
f.write("=== MÉTRICAS DE PERFORMANCE ===\n") |
|
f.write(f"Acurácia: {accuracy:.3f} ({accuracy*100:.1f}%)\n") |
|
f.write(f"F1-Score Macro: {f1_macro:.3f} ({f1_macro*100:.1f}%)\n") |
|
f.write(f"F1-Score Ponderado: {f1_weighted:.3f} ({f1_weighted*100:.1f}%)\n") |
|
f.write(f"F1-Score PT-BR: {f1_pt_br:.3f} ({f1_pt_br*100:.1f}%)\n") |
|
f.write(f"F1-Score PT-PT: {f1_pt_pt:.3f} ({f1_pt_pt*100:.1f}%)\n\n") |
|
|
|
f.write("=== RELATÓRIO DE CLASSIFICAÇÃO ===\n") |
|
f.write(classification_report(known_true, known_pred, labels=labels, zero_division=0)) |
|
f.write("\n\n=== MATRIZ DE CONFUSÃO ===\n") |
|
f.write(f" Predito\n") |
|
f.write(f" pt_br pt_pt\n") |
|
f.write(f"Real\n") |
|
f.write(f"pt_br {cm[0][0]:4d} {cm[0][1]:4d}\n") |
|
f.write(f"pt_pt {cm[1][0]:4d} {cm[1][1]:4d}\n\n") |
|
|
|
if errors: |
|
f.write(f"=== ERROS DE CLASSIFICAÇÃO ({len(errors)} arquivos) ===\n") |
|
for error in errors: |
|
f.write(f"{error['arquivo']}: {error['verdadeiro']} → {error['predito']} (conf: {error['confianca']:.3f})\n") |
|
else: |
|
f.write("=== ERROS DE CLASSIFICAÇÃO ===\n") |
|
f.write("Nenhum erro de classificação encontrado!\n") |
|
|
|
print(f"Métricas detalhadas salvas em: {detailed_metrics_file}") |
|
|
|
plt.show() |
|
else: |
|
print(f"\n⚠️ Não foi possível criar matriz de confusão:") |
|
print(f"Nenhum arquivo tinha label inferível do caminho ou nome.") |
|
print(f"Para usar a matriz de confusão, organize os arquivos em pastas 'pt_br' e 'pt_pt'") |
|
print(f"ou garanta que os nomes dos arquivos contenham essas strings.") |
|
|
|
if __name__ == "__main__": |
|
|
|
if len(sys.argv) == 1: |
|
print("=== Script de Teste de Classificação de Sotaques ===") |
|
print() |
|
print("Este script testa um modelo treinado em uma pasta de áudios usando janela deslizante.") |
|
print() |
|
print("Uso:") |
|
print(" python test_audio_folder.py <pasta_de_audios>") |
|
print() |
|
print("Exemplos:") |
|
print(" python test_audio_folder.py ./audios_teste") |
|
print(" python test_audio_folder.py ./audios_teste --output resultados.csv") |
|
print(" python test_audio_folder.py ./audios_teste --model_path ./results/checkpoint-20000") |
|
print() |
|
print("Parâmetros:") |
|
print(" pasta_de_audios : Pasta com arquivos de áudio para classificar") |
|
print(" --model_path : Caminho do modelo treinado (default: ./results/final_model)") |
|
print(" --output : Arquivo CSV para salvar resultados (opcional)") |
|
print(" --formats : Formatos suportados (default: .wav .mp3 .flac .m4a .ogg)") |
|
print() |
|
print("Funcionalidades da Janela Deslizante:") |
|
print("- Janelas de 5 segundos com sobreposição de 2.5s") |
|
print("- Áudios curtos: classificação direta") |
|
print("- Áudios longos: múltiplos segmentos combinados") |
|
print("- Resultado final: voto majoritário ponderado por confiança") |
|
print() |
|
print("O script irá:") |
|
print("1. Carregar o modelo treinado") |
|
print("2. Encontrar todos os arquivos de áudio na pasta") |
|
print("3. Classificar cada áudio usando janela deslizante") |
|
print("4. Mostrar estatísticas detalhadas dos resultados") |
|
print("5. Salvar resultados em CSV (se especificado)") |
|
print("6. Gerar matriz de confusão (se possível inferir labels)") |
|
sys.exit(0) |
|
|
|
main() |
|
|