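"""Fine-tune a Whisper checkpoint on fixed 2-second audio clips.

Expects a CSV with columns `path` and `transcription` (or `text`); every clip
is loaded as mono float32 at 16 kHz and cropped/padded to exactly 2 seconds.
"""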
import argparse
import os

import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import torch

from datasets import Dataset
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments, Seq2SeqTrainer
)

SAMPLE_RATE = 16000  # Whisper models expect 16 kHz audio
DUR_SEC = 2.0        # fixed clip length in seconds
N_SAMPLES = int(SAMPLE_RATE * DUR_SEC)  # 32000 samples per clip


def load_2s(path: str) -> np.ndarray:
    """Load audio as mono float32, resample to 16 kHz, and crop/pad to exactly 2 s."""
    try:
        audio, sr = sf.read(path, dtype="float32", always_2d=False)
        if audio.ndim > 1:  # downmix multichannel to mono
            audio = np.mean(audio, axis=1)
    except Exception:
        # soundfile cannot decode some formats (e.g. mp3); fall back to librosa.
        audio, sr = librosa.load(path, sr=None, mono=True)
    audio = audio.astype(np.float32)

    if sr != SAMPLE_RATE:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE).astype(np.float32)

    # Crop to 2 s, or zero-pad at the end if the clip is shorter.
    if len(audio) >= N_SAMPLES:
        audio = audio[:N_SAMPLES]
    else:
        pad = np.zeros(N_SAMPLES, dtype=np.float32)
        pad[: len(audio)] = audio
        audio = pad
    return np.ascontiguousarray(audio, dtype=np.float32)
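
# Example: load_2s("clip.wav") returns a float32 array of shape (32000,)
# (2.0 s x 16 kHz), regardless of the source file's length or sample rate.
# (The filename here is illustrative.)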


def make_transform():
    """Build an on-the-fly transform for Dataset.set_transform.

    datasets always calls the transform with a batch (a dict of lists), even
    for single-row access, and un-batches the result itself. The transform must
    therefore always return lists: returning bare scalars for size-1 batches
    would make datasets index into the audio array (or string) when un-batching.
    """
    def _tx(batch):
        return {
            "audio_2s": [load_2s(p) for p in batch["path"]],
            "text": list(batch["text"]),
            "path": list(batch["path"]),
        }
    return _tx
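
# set_transform applies _tx lazily at __getitem__ time, so audio files are
# decoded on demand by the DataLoader workers rather than all up front.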


class DataCollatorWhisper2s:
    """Collate 2 s clips into Whisper log-mel features and masked label ids."""

    def __init__(self, processor: WhisperProcessor):
        self.processor = processor
        self.tok = processor.tokenizer
        self.feat = processor.feature_extractor

    def __call__(self, features):
        audios = [f["audio_2s"] for f in features]
        texts = [f["text"] for f in features]

        # The feature extractor zero-pads every clip to 30 s internally, so 2 s
        # inputs still yield the fixed-size log-mel tensor Whisper expects.
        inputs = self.feat(audios, sampling_rate=SAMPLE_RATE, return_tensors="pt")
        input_features = inputs.input_features

        # Mask padding via the attention mask instead of matching pad_token_id:
        # Whisper's pad token is its EOS token, so an id match would also mask
        # the genuine end-of-sequence targets.
        tokenized = self.tok(texts, return_tensors="pt", padding=True)
        labels = tokenized.input_ids.masked_fill(tokenized.attention_mask.ne(1), -100)

        # The tokenizer already prepends <|startoftranscript|>; the model adds
        # it again when shifting labels right to build decoder inputs, so drop
        # it here to avoid training on a doubled start token.
        sot = self.tok.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == sot).all():
            labels = labels[:, 1:]

        return {"input_features": input_features, "labels": labels}
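
# For a batch of B clips the collator returns input_features of shape
# (B, 80, 3000) for whisper-small sized checkpoints (80 mel bins x 30 s of
# frames) and labels of shape (B, T) with padding positions set to -100,
# which the cross-entropy loss ignores.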


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="CSV with columns: path,transcription OR path,text")
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--base_model", default="openai/whisper-small")
    ap.add_argument("--lang", default="en")
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--batch", type=int, default=8)
    ap.add_argument("--num_workers", type=int, default=4)
    args = ap.parse_args()

    out_dir_abs = os.path.abspath(args.out_dir)
    os.makedirs(out_dir_abs, exist_ok=True)
    print(f"[INFO] Will save to: {out_dir_abs}")

    df = pd.read_csv(args.csv)
    if "transcription" in df.columns:
        df = df.rename(columns={"transcription": "text"})
    assert {"path", "text"}.issubset(df.columns), "CSV must have columns path and transcription/text"
    # Drop rows with missing or empty paths/transcripts.
    df = df.dropna(subset=["path", "text"])
    df = df[(df["path"].astype(str).str.len() > 0) & (df["text"].astype(str).str.len() > 0)]
    print(f"[INFO] Rows: {len(df)}")

    # 90/10 train/eval split with a fixed seed for reproducibility.
    ds = Dataset.from_pandas(df[["path", "text"]], preserve_index=False)
    ds = ds.train_test_split(test_size=0.1, seed=42)
    train_ds, eval_ds = ds["train"], ds["test"]

    processor = WhisperProcessor.from_pretrained(args.base_model, language=args.lang, task="transcribe")
    model = WhisperForConditionalGeneration.from_pretrained(args.base_model)
    # Pin the language/task prompt and disable default token suppression.
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.lang, task="transcribe")
    model.config.suppress_tokens = []
    model.gradient_checkpointing_enable()
    model.config.use_cache = False  # the KV cache is incompatible with gradient checkpointing

    tx = make_transform()
    train_ds.set_transform(tx)
    eval_ds.set_transform(tx)

    collator = DataCollatorWhisper2s(processor)

    train_args = Seq2SeqTrainingArguments(
        output_dir=out_dir_abs,
        per_device_train_batch_size=args.batch,
        per_device_eval_batch_size=args.batch,
        dataloader_num_workers=args.num_workers,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        num_train_epochs=args.epochs,
        evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        predict_with_generate=False,  # eval reports loss only, no decoding/WER
        fp16=torch.cuda.is_available(),  # fp16 needs a GPU; fall back to fp32 on CPU
        remove_unused_columns=False,  # keep "audio_2s"/"text" for the custom collator
        report_to=[],
    )
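
    # Effective batch size per device = args.batch * gradient_accumulation_steps
    # (e.g. 8 * 2 = 16 with the defaults).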

    trainer = Seq2SeqTrainer(
        model=model,
        args=train_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=processor.tokenizer,  # newer transformers prefer processing_class=processor
        data_collator=collator,
    )

    try:
        trainer.train()
    finally:
        # Save even if training crashes or is interrupted, so partial runs are
        # recoverable. trainer.save_model already writes the model weights, so
        # a separate model.save_pretrained call would be redundant.
        print("[INFO] Forcing save...")
        trainer.save_model(out_dir_abs)
        processor.save_pretrained(out_dir_abs)
        with open(os.path.join(out_dir_abs, "SAVE_OK"), "w") as f:
            f.write("ok\n")
        print(f"[DONE] Saved to: {out_dir_abs}")


if __name__ == "__main__":
    main()
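

# Example invocation (script name and paths are illustrative):
#   python finetune_whisper_2s.py --csv clips.csv --out_dir ./whisper-2s-ft \
#       --base_model openai/whisper-small --lang en --epochs 3 --batch 8
#
# The saved checkpoint can be reloaded later with
# WhisperProcessor.from_pretrained(out_dir) and
# WhisperForConditionalGeneration.from_pretrained(out_dir).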