# scripts/finetune_whisper_small.py
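# Fine-tunes openai/whisper-small on fixed-length 2-second clips listed in a CSV.
# Example invocation (paths are illustrative, not from the repo):
#   python scripts/finetune_whisper_small.py \
#       --csv data/train_manifest.csv \
#       --out_dir checkpoints/whisper-small-2s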
import argparse, os, pandas as pd, numpy as np, soundfile as sf, librosa, torch
from datasets import Dataset
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments, Seq2SeqTrainer,
)
SAMPLE_RATE = 16000
DUR_SEC = 2.0
N_SAMPLES = int(SAMPLE_RATE * DUR_SEC)

def load_2s(path: str) -> np.ndarray:
    """Load an audio file as mono float32, resampled to 16 kHz and trimmed or
    zero-padded to exactly 2 seconds."""
    try:
        audio, sr = sf.read(path, dtype="float32", always_2d=False)
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)  # downmix multi-channel to mono
    except Exception:
        # fallback if soundfile cannot read the file (e.g. mp3)
        audio, sr = librosa.load(path, sr=None, mono=True)
    audio = audio.astype(np.float32)
    if sr != SAMPLE_RATE:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE).astype(np.float32)
    if len(audio) >= N_SAMPLES:
        audio = audio[:N_SAMPLES]  # trim to 2 s
    else:
        pad = np.zeros(N_SAMPLES, dtype=np.float32)
        pad[: len(audio)] = audio  # right-pad with silence
        audio = pad
    return np.ascontiguousarray(audio, dtype=np.float32)
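
# For reference: load_2s("clip.wav") returns an np.ndarray of shape (32000,)
# (2 s at 16 kHz) regardless of the source file's duration, channel count, or
# sample rate ("clip.wav" is an illustrative path).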

def make_transform():
    # set_transform may be called with a single example or a batch; handle both
    def _tx(batch):
        single = not isinstance(batch["path"], list)
        paths = [batch["path"]] if single else batch["path"]
        texts = [batch["text"]] if single else batch["text"]
        # only attach the raw 2 s audio here; log-mel features are computed in the data collator
        audios = [load_2s(p) for p in paths]
        return {
            "audio_2s": audios[0] if single else audios,
            "text": texts[0] if single else texts,
            "path": paths[0] if single else paths,
        }
    return _tx

class DataCollatorWhisper2s:
    def __init__(self, processor: WhisperProcessor):
        self.processor = processor
        self.tok = processor.tokenizer
        self.feat = processor.feature_extractor

    def __call__(self, features):
        # each feature is {"audio_2s": np.ndarray, "text": str, "path": str}
        audios = [f["audio_2s"] for f in features]
        texts = [f["text"] for f in features]
        # Whisper input features (80 x 3000), computed on the fly;
        # the feature extractor already pads to this fixed shape
        inputs = self.feat(audios, sampling_rate=SAMPLE_RATE, return_tensors="pt")
        input_features = inputs.input_features  # [B, 80, 3000]
        # Tokenize targets and mask padding positions with -100 so the loss ignores them.
        # Whisper's pad token is also its EOS token, so mask via the attention mask
        # rather than by token id (masking by id would erase every real EOS as well).
        tokenized = self.tok(texts, return_tensors="pt", padding=True)
        labels = tokenized.input_ids.masked_fill(tokenized.attention_mask.ne(1), -100)
        return {"input_features": input_features, "labels": labels}

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="CSV with columns: path,transcription OR path,text")
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--base_model", default="openai/whisper-small")
    ap.add_argument("--lang", default="en")
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--batch", type=int, default=8)
    ap.add_argument("--num_workers", type=int, default=4)
    args = ap.parse_args()

    out_dir_abs = os.path.abspath(args.out_dir)
    os.makedirs(out_dir_abs, exist_ok=True)
    print(f"[INFO] Will save to: {out_dir_abs}")

    # read the CSV once and normalize the transcript column name
    df = pd.read_csv(args.csv)
    if "transcription" in df.columns:
        df = df.rename(columns={"transcription": "text"})
    assert {"path", "text"}.issubset(df.columns), "CSV must have columns path and transcription/text"
    df = df.dropna(subset=["path", "text"])
    df = df[(df["path"].astype(str).str.len() > 0) & (df["text"].astype(str).str.len() > 0)]
    print(f"[INFO] Rows: {len(df)}")

    # Tiny, path-only dataset; audio is loaded lazily via the transform
    ds = Dataset.from_pandas(df[["path", "text"]], preserve_index=False)
    ds = ds.train_test_split(test_size=0.1, seed=42)
    train_ds, eval_ds = ds["train"], ds["test"]

    processor = WhisperProcessor.from_pretrained(args.base_model, language=args.lang, task="transcribe")
    model = WhisperForConditionalGeneration.from_pretrained(args.base_model)
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.lang, task="transcribe")
    model.config.suppress_tokens = []  # better for short clips
    model.gradient_checkpointing_enable()  # trades a little compute for less VRAM

    # lazy transforms (no big arrays stored in the dataset)
    tx = make_transform()
    train_ds.set_transform(tx)
    eval_ds.set_transform(tx)
    collator = DataCollatorWhisper2s(processor)

    train_args = Seq2SeqTrainingArguments(
        output_dir=out_dir_abs,
        per_device_train_batch_size=args.batch,
        per_device_eval_batch_size=args.batch,
        dataloader_num_workers=args.num_workers,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        num_train_epochs=args.epochs,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        predict_with_generate=False,
        fp16=torch.cuda.is_available(),  # fp16 needs a CUDA device; fall back to fp32 on CPU
        remove_unused_columns=False,  # keep our custom fields for the collator
        report_to=[],
    )
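    # Effective per-device train batch size = per_device_train_batch_size *
    # gradient_accumulation_steps = args.batch * 2 (16 with the defaults).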

    trainer = Seq2SeqTrainer(
        model=model,
        args=train_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        tokenizer=processor.tokenizer,
        data_collator=collator,
    )

    try:
        trainer.train()
    finally:
        # save even if training is interrupted, so partial runs are recoverable
        print("[INFO] Forcing save...")
        trainer.save_model(out_dir_abs)
        model.save_pretrained(out_dir_abs)
        processor.save_pretrained(out_dir_abs)
        with open(os.path.join(out_dir_abs, "SAVE_OK"), "w") as f:
            f.write("ok\n")
        print(f"[DONE] Saved to: {out_dir_abs}")


if __name__ == "__main__":
    main()
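
# After training, the saved checkpoint can be loaded for inference. A minimal
# sketch (the clip path is hypothetical; out_dir is the --out_dir used above):
#   processor = WhisperProcessor.from_pretrained(out_dir)
#   model = WhisperForConditionalGeneration.from_pretrained(out_dir).eval()
#   feats = processor(load_2s("clip.wav"), sampling_rate=SAMPLE_RATE,
#                     return_tensors="pt").input_features
#   ids = model.generate(feats)
#   print(processor.batch_decode(ids, skip_special_tokens=True)[0])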